Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ struct EmbeddingParams {
uint32_t groups_per_row;
uint32_t bytes_per_row;
uint32_t total_blocks;
uint32_t _pad;
uint32_t is_linear_weight;
};
static_assert(
sizeof(EmbeddingParams) == 32,
Expand Down Expand Up @@ -60,7 +60,8 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const auto& indices = graph.get_tensor(indices_id);
const auto& out = graph.get_tensor(out_id);

// Only the flat weight path is supported (linear-block unsupported).
// is_linear_weight selects the nibble packing (false: even dim = high nibble;
// true: even dim = low nibble). The shader handles both via a uniform.
bool is_linear = false;
if (graph.get_value_type(is_linear_weight_id) ==
WebGPUGraph::ValueType::Bool) {
Expand All @@ -73,10 +74,6 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector<int>& args) {
throw std::runtime_error(
"WebGPU embedding_q4gsw: is_linear_weight must be Bool or Int");
}
if (is_linear) {
throw std::runtime_error(
"WebGPU embedding_q4gsw: is_linear_weight=true is unsupported");
}

if (weight.dims.size() < 2 || scales.dims.size() < 2 || out.dims.empty() ||
indices.dims.empty()) {
Expand Down Expand Up @@ -150,6 +147,7 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector<int>& args) {
params.groups_per_row = groups_per_row;
params.bytes_per_row = bytes_per_row;
params.total_blocks = static_cast<uint32_t>(total_blocks);
params.is_linear_weight = is_linear ? 1u : 0u;

WGPUBufferDescriptor uniform_desc = {};
uniform_desc.size = sizeof(EmbeddingParams);
Expand Down
10 changes: 6 additions & 4 deletions backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct Params {
groups_per_row: u32,
bytes_per_row: u32,
total_blocks: u32,
_pad: u32,
is_linear_weight: u32,
}
@group(0) @binding(4) var<uniform> params: Params;

Expand All @@ -37,11 +37,13 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let byte_idx = row_byte_base + (dim >> 1u);
let word = t_weight[byte_idx >> 2u];
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
// Nibble packing depends on is_linear_weight: non-linear maps even dim ->
// high nibble / odd -> low; linear maps even -> low / odd -> high.
var nib: u32;
if ((dim & 1u) == 0u) {
nib = (b >> 4u) & 0x0Fu; // even dim -> high nibble
if (((dim & 1u) == 0u) != (params.is_linear_weight != 0u)) {
nib = (b >> 4u) & 0x0Fu; // high nibble
} else {
nib = b & 0x0Fu; // odd dim -> low nibble
nib = b & 0x0Fu; // low nibble
}
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
let scale = t_scales[token * params.groups_per_row + dim / params.group_size];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace executorch::backends::webgpu {

// @generated from embedding_q4gsw.wgsl - DO NOT EDIT.
// wgsl-sha256: 1fec9ed315696a88bb7db6c16454fc80e08ff73b0e39720b54515fda4ee1ef7c
// wgsl-sha256: 94da1061b49b62556a79020182a4989439a7c51f919e83d577536c5b6d25f487
inline constexpr const char* kEmbeddingQ4gswWGSL = R"(
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(1) var<storage, read> t_indices: array<i32>;
Expand All @@ -28,7 +28,7 @@ struct Params {
groups_per_row: u32,
bytes_per_row: u32,
total_blocks: u32,
_pad: u32,
is_linear_weight: u32,
}
@group(0) @binding(4) var<uniform> params: Params;

Expand All @@ -54,11 +54,13 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let byte_idx = row_byte_base + (dim >> 1u);
let word = t_weight[byte_idx >> 2u];
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
// Nibble packing depends on is_linear_weight: non-linear maps even dim ->
// high nibble / odd -> low; linear maps even -> low / odd -> high.
var nib: u32;
if ((dim & 1u) == 0u) {
nib = (b >> 4u) & 0x0Fu; // even dim -> high nibble
if (((dim & 1u) == 0u) != (params.is_linear_weight != 0u)) {
nib = (b >> 4u) & 0x0Fu; // high nibble
} else {
nib = b & 0x0Fu; // odd dim -> low nibble
nib = b & 0x0Fu; // low nibble
}
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
let scale = t_scales[token * params.groups_per_row + dim / params.group_size];
Expand Down
Loading