[MLX] Gemma4-31B ondevice sampling #20561
print("Done.")

class _MLXSampleWrapper(nn.Module):
There was a problem hiding this comment.
Ideally, SamplingHead from backends/mlx/llm would be directly usable. The fact that we keep defining wrappers suggests that maybe it should be refactored so that it works without one?
In terms of shape, I expect sampling to work on (B, vocab) and return (B) (tokens).
I don't see a case for (B, S, vocab).
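For illustration, the (B, vocab) → (B) contract described above can be sketched in plain Python. This is only a reference-style sketch, not the MLX op: the name `sample_tokens` is hypothetical, and the conventions that temperature 0 means greedy, top_k 0 means "disabled", and top_p 1.0 means "disabled" are assumptions.

```python
import math
import random


def sample_tokens(logits, temperature=1.0, top_k=0, top_p=1.0, seed=0):
    """Return one token id per row of a (B, vocab) nested list of logits.

    Hypothetical helper for illustration; conventions are assumptions.
    """
    rng = random.Random(seed)
    tokens = []
    for row in logits:
        if temperature == 0.0:
            # Assumed convention: temperature 0 selects the argmax (greedy).
            tokens.append(max(range(len(row)), key=row.__getitem__))
            continue
        # Candidate ids ordered by logit, highest first.
        order = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        if top_k > 0:
            order = order[:top_k]
        # Numerically stable softmax over the surviving candidates.
        scaled = [row[i] / temperature for i in order]
        peak = max(scaled)
        weights = [math.exp(s - peak) for s in scaled]
        total = sum(weights)
        probs = [w / total for w in weights]
        if top_p < 1.0:
            # Keep the smallest prefix whose cumulative probability reaches top_p.
            cumulative, keep = 0.0, 0
            for p in probs:
                cumulative += p
                keep += 1
                if cumulative >= top_p:
                    break
            order, probs = order[:keep], probs[:keep]
        tokens.append(rng.choices(order, weights=probs, k=1)[0])
    return tokens
```

The output has one entry per batch row, so there is no sequence dimension to handle in the sampler itself.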
There was a problem hiding this comment.
I've made SamplingHead operate on (B, vocab) → (B) so it's model-agnostic, but it still needs a positional adapter for torch.export since its sampling args are keyword-only.
SamplingHead wraps the whole model (self.model(*args)) behind a generic, keyword-only forward(*args, temperature, seed, …), so I think each model needs a small adapter that maps its fixed forward(tokens, input_pos, temperature, top_p, seed) onto that signature.
I checked the existing heads, and they avoid this by being tensor layers: lm_head (nn.Linear) and FlowMatchingHead take a tensor in and return a tensor out, and they are called inside the model's forward, so they ride on the model's already-positional export entry point. Would we want to change SamplingHead to a tensor head like this? The CUDA path avoids it too with a plain free function: sampler.py:sample(logits, temperature) is called inside model.forward (return sample(logits, temperature)).
There was a problem hiding this comment.
What if we did:
    def forward(self, *args, **kwargs):
        # The last four positional args are the sampling inputs; the rest go to the model.
        *model_args, temperature, top_k, top_p, seed = args
        logits = self.model(*model_args, **kwargs)  # [B, V]
        return torch.ops.mlx.sample(logits, temperature, top_k, top_p, seed)
Then in Qwen code (and here), we could just do:
model = SamplingHead(model)
example_args = (example_tokens, example_input_pos, temp, top_k, top_p, seed)
With no wrapper?
There was a problem hiding this comment.
This should work, thank you!
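A minimal sketch of the positional-unpacking idea from this thread, written in plain Python so it runs without torch or the MLX backend. `SamplingHeadSketch`, `toy_model`, and `greedy_sample` are hypothetical stand-ins; in particular, `greedy_sample` only mimics the argument order of the sample op and ignores the sampling parameters.

```python
VOCAB = 5


def toy_model(tokens, input_pos):
    # Hypothetical model: one-hot logits at (last token + input_pos) % VOCAB per row.
    return [
        [1.0 if v == (row[-1] + input_pos) % VOCAB else 0.0 for v in range(VOCAB)]
        for row in tokens
    ]


def greedy_sample(logits, temperature, top_k, top_p, seed):
    # Stand-in for the on-device sample op: argmax per row, other args unused.
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


class SamplingHeadSketch:
    def __init__(self, model, sample_fn):
        self.model = model
        self.sample_fn = sample_fn

    def forward(self, *args, **kwargs):
        # Trailing four positionals are sampling inputs; everything before is model input.
        *model_args, temperature, top_k, top_p, seed = args
        logits = self.model(*model_args, **kwargs)  # (B, vocab)
        return self.sample_fn(logits, temperature, top_k, top_p, seed)


head = SamplingHeadSketch(toy_model, greedy_sample)
example_args = ([[1, 2], [3, 4]], 2, 0.0, 0, 1.0, 0)
next_tokens = head.forward(*example_args)
```

Because every input is positional, the same tuple can serve as the example arguments for export without a per-model adapter, which is the point made above.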
@claude review this code
Claude finished @metascroy's task in 2m 18s
Code review
Overall this is a clean, well-guarded port of the Qwen3.5 MoE on-device sampling pattern (#20545). The runtime guards, range-checks, and the per-token seed schedule are consistent with that PR. A few notes below. On @metascroy's question: why not reuse
Result<DecodeResult> decode_one(const SamplingConfig& sampling) override {
if (sampling.top_p != 1.0f || sampling.top_k != 0 || sampling.seed != 0) {
if (sampling.top_k != 0) {
There was a problem hiding this comment.
I think top_k is supported now?
There was a problem hiding this comment.
Yes, I will update this.
Summary
Lets the MLX-exported Gemma 4 31B model sample the next token on-device instead of returning logits for host-side sampling. Sampling is opt-in at export (--sample); temperature, top_k, top_p, and seed are runtime inputs, and the runner increments the seed per token.
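As a rough sketch of what "the runner increments the seed per token" could look like, the loop below passes the sampling inputs on every step and advances the seed by one each time. The names `generate` and `step` are hypothetical, and using `seed + i` as the schedule is an assumption; the actual runner may differ.

```python
def generate(step, first_token, max_new_tokens, temperature, top_k, top_p, seed):
    """Decode loop sketch: `step` stands in for the exported model call (hypothetical)."""
    tokens = [first_token]
    for i in range(max_new_tokens):
        # Assumed schedule: a distinct seed per generated token, derived from the base seed.
        next_token = step(tokens[-1], i, temperature, top_k, top_p, seed + i)
        tokens.append(next_token)
    return tokens


seen_seeds = []


def fake_step(token, input_pos, temperature, top_k, top_p, seed):
    # Records the seed it was given and returns a deterministic dummy token.
    seen_seeds.append(seed)
    return token + 1


output = generate(fake_step, 10, 3, 0.8, 40, 0.95, 100)
```

Varying the seed per step keeps a fixed base seed reproducible while avoiding the same random draw at every position.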
Changes