Skip to content

[ExecuTorch][WebGPU] 2D compute dispatch — lift the 65535 per-dim cap (prefill path)#20583

Open
JulianCloudNTH wants to merge 4 commits into
gh/JulianCloudNTH/75/basefrom
gh/JulianCloudNTH/75/head
Open

[ExecuTorch][WebGPU] 2D compute dispatch — lift the 65535 per-dim cap (prefill path)#20583
JulianCloudNTH wants to merge 4 commits into
gh/JulianCloudNTH/75/basefrom
gh/JulianCloudNTH/75/head

Conversation

@JulianCloudNTH

@JulianCloudNTH JulianCloudNTH commented Jun 28, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.

Problem: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (maxComputeWorkgroupsPerDimension, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill.

Solution: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from @builtin(num_workgroups).

  • Before: compute_1d_workgroup_count throws if count > limit; dispatch (count, 1, 1).
  • After: compute_2d_workgroup_count returns {count, 1} (fast path) or {limit, div_up(count, limit)}; dispatch (x, y, 1).

Implementation:

  • WgCount + pure fold_workgroup_count_2d + compute_2d_workgroup_count in WebGPUUtils.h (device-free, unit-testable; queried_max_workgroups factored out of the 1D path)
  • WebGPUDispatch.workgroup_count_y (default 1, declared last so existing aggregate inits are unchanged); both dispatchWorkgroups calls + the profiling record pass (x, y, 1)
  • Per-kernel in-shader reconstruction: thread-form idx = gid.x + gid.y*(num_workgroups.x*wg_size) (QK/AV/add); row-form row_idx = wid.x + wid.y*num_workgroups.x (softmax — keeps a valid predicate, not an early return, so workgroupBarrier()s stay uniform)
  • Sdpa.cpp: QK/softmax/AV counts via the 2D helper; the dynamic-input_pos resize hook recomputes both x and y for QK
  • Mirrors Vulkan dispatch (Vulkan itself does not guard the per-dim limit)

Constraints:

  • y=1 fast path keeps every non-folded dispatch byte-identical to the prior 1D path
  • Scope = prefill path only; rms_norm/embedding/lm_head/update_cache are row/token-indexed and never hit the cap, so they keep the 1D path
  • Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the uint32 element guard fires first at S~11585)

Co-authored-with: Claude Code.
@exported-using-ghexport

Differential Revision: D109517684

Differential Revision: D109517684

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Jun 28, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20583

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit f98c501 with merge base db5f1ba (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 28, 2026
@JulianCloudNTH

Copy link
Copy Markdown
Contributor Author

@claude review and check for any areas or opportunities for modularization

@claude

claude Bot commented Jun 29, 2026

Copy link
Copy Markdown

Claude finished @JulianCloudNTH's task in 3m 55s —— View job


Review: WebGPU 2D compute dispatch (#20583)

  • Read the full diff
  • Review correctness & WGSL shader changes
  • Assess modularization opportunities
  • Post review

Solid, well-scoped change. The y=1 fast path keeping non-folded dispatches byte-identical is the right call, the device-free fold_workgroup_count_2d is genuinely unit-testable, and keeping a valid predicate (not an early return) in softmax to preserve workgroupBarrier() uniformity is exactly correct WGSL — that's the subtle bug this avoids. Correctness of the index reconstruction checks out: for the thread-form shaders num_workgroups.x == max_count and lid.y == 0 (workgroup_size y is 1), so idx = gid.x + gid.y*(num_workgroups.x*wg_size) reconstructs the linear thread id, and the idx >= total guard absorbs the over-dispatch from rounding y up. 👍

🔴 Blocking: lint failure (CI red)

backends/webgpu/runtime/ops/add/BinaryOp.cpp:167 — clang-format wants the add_resize lambda capture list reflowed (it fits on fewer lines now). This is the only CI failure. Just run:

lintrunner -a

Fix this →

Modularization opportunities

1. The WebGPUDispatch aggregate-init is the main smell. To avoid touching the ~10 unchanged call sites, workgroup_count_y was declared after kernel_name (WebGPUGraph.h:51), which forces the awkward positional init in BinaryOp.cpp:162:

graph.add_dispatch({pipeline, bind_group, workgroup_count.x, "", workgroup_count.y});
//                                                            ^^ kernel_name placeholder

A field that's logically adjacent to workgroup_count_x now sits two members away purely for init-ordering, and callers must thread an empty-string placeholder. Cleaner options:

  • Add an add_dispatch overload taking utils::WgCount directly, so the binary-op path reads graph.add_dispatch({pipeline, bind_group, wgc}) with no placeholder; or
  • Keep x/y adjacent in the struct and update the other call sites (they default y=1 trivially).

2. dispatch.x = wgc.x; dispatch.y = wgc.y; is repeated 4× across the resize hooks (BinaryOp.cpp:186-187, Sdpa.cpp:643-644,651-652,661-662). A one-liner setter removes the copy-paste and the risk of updating one without the other:

inline void set_workgroup_count(WebGPUDispatch& d, utils::WgCount c) {
  d.workgroup_count_x = c.x;
  d.workgroup_count_y = c.y;
}

3. build_dispatch in Sdpa.cpp gained a separate workgroup_count_y param (between workgroup_count_x and wg_size). Passing a single utils::WgCount would shrink each of the three call sites (QK/softmax/AV) and the update_cache call (which passes a literal 1) and keep the dispatch geometry as one value.

4. Minor: reuse div_up. compute_2d_workgroup_count (WebGPUUtils.h:104) and fold_workgroup_count_2d:73 open-code (a + b - 1) / b, while compute_1d_workgroup_count:88 calls div_up. Use div_up in both for consistency.

5. Shader-side duplication (noting, not necessarily actionable). The thread-form reconstruction let idx = gid.x + gid.y * (num_workgroups.x * wg_size); is now identical in binary_add.wgsl, sdpa_compute_attn_weights.wgsl, and sdpa_compute_out.wgsl. WGSL has no include mechanism in this codebase, so a shared snippet isn't free — but if more ops adopt the 2D fold, a generated prelude string concatenated at pipeline-build time would centralize it. Fine to leave as-is for this PR's scope.

Nits / questions

  • num_workgroups.x * wg_size overflow. For add, wg_size=256 and num_workgroups.x ≤ 65535~16.7M, well within u32. The PR notes the uint32 element guard fires first at S~11585, so this is safe today — but the product is implicit; if a future op uses a large wg_size with a near-limit x it could overflow silently. A brief comment in the shaders stating the safe envelope would help the next reader.
  • queried_max_workgroups is called per dispatch (build-time, not hot path), each doing a wgpuDeviceGetLimits. Negligible, just flagging it isn't cached.

None of the modularization items are blockers — only the lint fix is required to turn CI green. Items 1–3 would meaningfully reduce the positional-init fragility if you want to fold them in.

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants