From a45d9baea8ebcf9d9977392aca7fe2d1893b28ba Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Tue, 30 Jun 2026 10:25:58 -0700 Subject: [PATCH] embedding_q4gsw: support is_linear_weight packing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The WebGPU `embedding_q4gsw` op failed to load (`Error::DelegateInvalidCompatibility`, error 48): `WebGPUGraph::build()` threw `is_linear_weight=true is unsupported` and `WebGPUBackend.cpp` returned error 48 at `load_forward`. The `quantized_decomposed.embedding_4bit` and torchao fusions in `backends/vulkan/patterns/quantized_embedding.py` repack the embedding weight into the 4-bit linear-layer nibble convention (low nibble = even dim, high nibble = odd dim) and emit `et_vk.embedding_q4gsw.default(..., is_linear_weight=true)` — always on OSS main, and for tied (embedding/LM-head-shared) weights via `_detect_tied_linear_weight`. The Vulkan runtime supports both packings via its compile-time `_linear_weight` shader variant; the WebGPU runtime rejected `true` outright, so any model whose embedding took that path failed to delegate. This teaches the WebGPU runtime both packings. The handler now forwards `is_linear_weight` to the shader through the uniform (the spare `_pad` field is repurposed; struct size unchanged), and `embedding_q4gsw.wgsl` selects the nibble with `use_high = is_even != is_linear_weight`. The `is_linear_weight=false` path is byte-identical to before. Differential Revision: D110211746 --- .../runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp | 10 ++++------ .../runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl | 10 ++++++---- .../ops/embedding_q4gsw/embedding_q4gsw_wgsl.h | 12 +++++++----- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp b/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp index 5801b650f27..d24c693e486 100644 --- a/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp +++ b/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp @@ -30,7 +30,7 @@ struct EmbeddingParams { uint32_t groups_per_row; uint32_t bytes_per_row; uint32_t total_blocks; - uint32_t _pad; + uint32_t is_linear_weight; }; static_assert( sizeof(EmbeddingParams) == 32, @@ -60,7 +60,8 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector& args) { const auto& indices = graph.get_tensor(indices_id); const auto& out = graph.get_tensor(out_id); - // Only the flat weight path is supported (linear-block unsupported). + // is_linear_weight selects the nibble packing (false: even dim = high nibble; + // true: even dim = low nibble). The shader handles both via a uniform. bool is_linear = false; if (graph.get_value_type(is_linear_weight_id) == WebGPUGraph::ValueType::Bool) { @@ -73,10 +74,6 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector& args) { throw std::runtime_error( "WebGPU embedding_q4gsw: is_linear_weight must be Bool or Int"); } - if (is_linear) { - throw std::runtime_error( - "WebGPU embedding_q4gsw: is_linear_weight=true is unsupported"); - } if (weight.dims.size() < 2 || scales.dims.size() < 2 || out.dims.empty() || indices.dims.empty()) { @@ -150,6 +147,7 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector& args) { params.groups_per_row = groups_per_row; params.bytes_per_row = bytes_per_row; params.total_blocks = static_cast(total_blocks); + params.is_linear_weight = is_linear ? 1u : 0u; WGPUBufferDescriptor uniform_desc = {}; uniform_desc.size = sizeof(EmbeddingParams); diff --git a/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl b/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl index f16f3760d1c..fecb5f1e28a 100644 --- a/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl +++ b/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl @@ -11,7 +11,7 @@ struct Params { groups_per_row: u32, bytes_per_row: u32, total_blocks: u32, - _pad: u32, + is_linear_weight: u32, } @group(0) @binding(4) var params: Params; @@ -37,11 +37,13 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let byte_idx = row_byte_base + (dim >> 1u); let word = t_weight[byte_idx >> 2u]; let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu; + // Nibble packing depends on is_linear_weight: non-linear maps even dim -> + // high nibble / odd -> low; linear maps even -> low / odd -> high. var nib: u32; - if ((dim & 1u) == 0u) { - nib = (b >> 4u) & 0x0Fu; // even dim -> high nibble + if (((dim & 1u) == 0u) != (params.is_linear_weight != 0u)) { + nib = (b >> 4u) & 0x0Fu; // high nibble } else { - nib = b & 0x0Fu; // odd dim -> low nibble + nib = b & 0x0Fu; // low nibble } let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7] let scale = t_scales[token * params.groups_per_row + dim / params.group_size]; diff --git a/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw_wgsl.h b/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw_wgsl.h index e44c06a2ac5..db26795a021 100644 --- a/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw_wgsl.h +++ b/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from embedding_q4gsw.wgsl - DO NOT EDIT. -// wgsl-sha256: 1fec9ed315696a88bb7db6c16454fc80e08ff73b0e39720b54515fda4ee1ef7c +// wgsl-sha256: 94da1061b49b62556a79020182a4989439a7c51f919e83d577536c5b6d25f487 inline constexpr const char* kEmbeddingQ4gswWGSL = R"( @group(0) @binding(0) var t_out: array; @group(0) @binding(1) var t_indices: array; @@ -28,7 +28,7 @@ struct Params { groups_per_row: u32, bytes_per_row: u32, total_blocks: u32, - _pad: u32, + is_linear_weight: u32, } @group(0) @binding(4) var params: Params; @@ -54,11 +54,13 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let byte_idx = row_byte_base + (dim >> 1u); let word = t_weight[byte_idx >> 2u]; let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu; + // Nibble packing depends on is_linear_weight: non-linear maps even dim -> + // high nibble / odd -> low; linear maps even -> low / odd -> high. var nib: u32; - if ((dim & 1u) == 0u) { - nib = (b >> 4u) & 0x0Fu; // even dim -> high nibble + if (((dim & 1u) == 0u) != (params.is_linear_weight != 0u)) { + nib = (b >> 4u) & 0x0Fu; // high nibble } else { - nib = b & 0x0Fu; // odd dim -> low nibble + nib = b & 0x0Fu; // low nibble } let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7] let scale = t_scales[token * params.groups_per_row + dim / params.group_size];