[ExecuTorch][WebGPU] 2D compute dispatch — lift the 65535 per-dim cap (prefill path)#20583
[ExecuTorch][WebGPU] 2D compute dispatch — lift the 65535 per-dim cap (prefill path)#20583JulianCloudNTH wants to merge 4 commits into
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20583
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit f98c501 with merge base db5f1ba ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
|
@claude review and check for any areas or opportunities for modularization |
|
Claude finished @JulianCloudNTH's task in 3m 55s —— View job Review: WebGPU 2D compute dispatch (#20583)
Solid, well-scoped change. The 🔴 Blocking: lint failure (CI red)
lintrunner -aModularization opportunities1. The graph.add_dispatch({pipeline, bind_group, workgroup_count.x, "", workgroup_count.y});
// ^^ kernel_name placeholderA field that's logically adjacent to
2. inline void set_workgroup_count(WebGPUDispatch& d, utils::WgCount c) {
d.workgroup_count_x = c.x;
d.workgroup_count_y = c.y;
}3. 4. Minor: reuse 5. Shader-side duplication (noting, not necessarily actionable). The thread-form reconstruction Nits / questions
None of the modularization items are blockers — only the lint fix is required to turn CI green. Items 1–3 would meaningfully reduce the positional-init fragility if you want to fold them in. |
Stack from ghstack (oldest at bottom):
Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.
Problem: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (
maxComputeWorkgroupsPerDimension, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill.Solution: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from
@builtin(num_workgroups).compute_1d_workgroup_countthrows ifcount > limit; dispatch(count, 1, 1).compute_2d_workgroup_countreturns{count, 1}(fast path) or{limit, div_up(count, limit)}; dispatch(x, y, 1).Implementation:
WgCount+ purefold_workgroup_count_2d+compute_2d_workgroup_countinWebGPUUtils.h(device-free, unit-testable;queried_max_workgroupsfactored out of the 1D path)WebGPUDispatch.workgroup_count_y(default 1, declared last so existing aggregate inits are unchanged); bothdispatchWorkgroupscalls + the profiling record pass(x, y, 1)idx = gid.x + gid.y*(num_workgroups.x*wg_size)(QK/AV/add); row-formrow_idx = wid.x + wid.y*num_workgroups.x(softmax — keeps avalidpredicate, not an early return, soworkgroupBarrier()s stay uniform)Sdpa.cpp: QK/softmax/AV counts via the 2D helper; the dynamic-input_posresize hook recomputes both x and y for QKConstraints:
y=1fast path keeps every non-folded dispatch byte-identical to the prior 1D pathrms_norm/embedding/lm_head/update_cacheare row/token-indexed and never hit the cap, so they keep the 1D pathuint32element guard fires first at S~11585)Co-authored-with: Claude Code.
@exported-using-ghexport
Differential Revision: D109517684
Differential Revision: D109517684