Enable kernels-community/metal-flash-sdpa on MPS (#45974)
* Enable kernels-community/metal-flash-sdpa on MPS via generate_batch
Two small fixes so `attn_implementation="kernels-community/metal-flash-sdpa"`
works end-to-end on Apple Silicon (`generate` and `generate_batch`):
* `modeling_flash_attention_utils._flash_attention_forward`: the "no padding"
branch unconditionally called `flash_fn`, which is `None` for varlen-only
kernels (the metal kernel only ships `flash_attn_varlen_func`). Synthesize
`cu_seqlens` for the dense batched layout and route through `flash_varlen_fn`
in that case. `.contiguous()` before reshape is required: the cached
K/V (post-transpose) is non-contiguous and the Metal kernel reads garbage
off it during decode, producing nonsense tokens.
* `continuous_batching/requests.get_device_and_memory_breakdown`: on MPS,
`torch.mps.driver_allocated_memory()` returns bytes currently held by the
Metal driver (≈0 right after process start), not the total memory available to
the device. Use `torch.mps.recommended_max_memory()` for the total and
`torch.mps.current_allocated_memory()` for the running allocation. Without this,
`infer_num_blocks_and_max_batch_tokens`
either returns a negative `num_blocks` or refuses to allocate, so
`generate_batch` was unusable on MPS regardless of the chosen attention.
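To illustrate why the choice of "total" matters, here is a minimal sketch. The sizing rule, the `MemoryBreakdown` type, and all byte counts are hypothetical and are not the actual `infer_num_blocks_and_max_batch_tokens` logic. On a real machine the two inputs would come from `torch.mps.recommended_max_memory()` and `torch.mps.current_allocated_memory()`.

```python
from dataclasses import dataclass

MiB = 1024 ** 2
GiB = 1024 ** 3


@dataclass
class MemoryBreakdown:
    total: int      # bytes the cache sizing treats as the device total
    allocated: int  # bytes currently occupied by live tensors


def available_blocks(mem: MemoryBreakdown, block_bytes: int, max_percent: int = 90) -> int:
    """Hypothetical sizing rule: spend max_percent of total, minus what is in use."""
    budget = mem.total * max_percent // 100 - mem.allocated
    return budget // block_bytes


# Old behaviour (illustrative numbers): "total" is whatever the Metal driver
# currently holds, which is barely more than the model weights themselves.
broken = MemoryBreakdown(total=1100 * MiB, allocated=1000 * MiB)

# Fixed behaviour (illustrative numbers): "total" is the recommended maximum
# working set for the device.
fixed = MemoryBreakdown(total=16 * GiB, allocated=1000 * MiB)

broken_blocks = available_blocks(broken, block_bytes=2 * MiB)
fixed_blocks = available_blocks(fixed, block_bytes=2 * MiB)
```

With the driver-held figure as the total, the budget goes negative as soon as the weights take up most of what the driver has reserved. With the recommended maximum it stays positive.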
Bench (gsm8k 100 samples, Qwen2.5-0.5B-Instruct, MPS fp16, generate_batch):
impl                                 time(s)   tok/s   acc
sdpa                                 149.33    158.4   30/100
kernels-community/metal-flash-sdpa    89.78    256.0   32/100
1.66x speedup, accuracy within noise.
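The ratios can be re-derived from the table; a small check using only the numbers reported above. The variable names are illustrative.

```python
# Figures copied from the benchmark table (gsm8k, 100 samples).
sdpa_time_s, sdpa_tok_s, sdpa_correct = 149.33, 158.4, 30
metal_time_s, metal_tok_s, metal_correct = 89.78, 256.0, 32

# Wall-clock speedup: how much less time the metal kernel run took.
time_speedup = sdpa_time_s / metal_time_s

# Throughput ratio: tokens per second, metal over sdpa.
throughput_ratio = metal_tok_s / sdpa_tok_s

# Accuracy gap in samples out of 100.
accuracy_delta = metal_correct - sdpa_correct

print(f"time speedup     {time_speedup:.2f}x")
print(f"throughput ratio {throughput_ratio:.2f}x")
print(f"accuracy delta   {accuracy_delta:+d} samples")
```

The 1.66x figure is the wall-clock ratio. The tok/s ratio comes out slightly lower, at about 1.62x.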
* serve: auto-select metal-flash-sdpa attention on MPS
When `transformers serve` runs on Apple Silicon (`--device auto` or `mps`)
with `kernels` installed and no explicit `--attn-implementation` flag, default
the attention to `kernels-community/metal-flash-sdpa` instead of plain SDPA.
On the 100-sample gsm8k benchmark (Qwen2.5-0.5B-Instruct, MPS fp16,
generate_batch) it's a 1.66x wall-clock speedup (149.33 s -> 89.78 s; throughput
158 -> 256 tok/s, about 1.62x) with token-for-token parity for greedy decoding.
Users who don't want it can opt out with `--attn-implementation sdpa`.
Help text on the `--attn-implementation` flag also now lists the kernels-hub
syntax explicitly.
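The selection rule described above can be sketched roughly as follows. This is not the actual `transformers serve` code: the function names and the `resolved_device` argument (the device after `auto` has been resolved) are assumptions for illustration.

```python
import importlib.util
from typing import Optional

METAL_FLASH_SDPA = "kernels-community/metal-flash-sdpa"


def kernels_is_installed() -> bool:
    """True if the `kernels` package can be imported in this environment."""
    return importlib.util.find_spec("kernels") is not None


def default_attn_implementation(
    resolved_device: str,
    kernels_installed: bool,
    explicit: Optional[str] = None,
) -> Optional[str]:
    """Hypothetical helper: pick the attention implementation to pass to the model."""
    # An explicit --attn-implementation always wins, including the `sdpa` opt-out.
    if explicit is not None:
        return explicit
    # Auto-select the Metal kernel only on MPS and only when `kernels` is available.
    if resolved_device == "mps" and kernels_installed:
        return METAL_FLASH_SDPA
    # Otherwise return None and leave the library default untouched.
    return None
```

For example, `default_attn_implementation("mps", kernels_is_installed())` would yield the Metal kernel on a machine with `kernels` installed, while passing `explicit="sdpa"` returns `"sdpa"` unchanged.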
* Revert cu_seqlens synthesis in _flash_attention_forward
Don't build cu_seqlens on the fly inside the modeling forward. For varlen-only
kernels the non-padding `else` branch can keep failing with a NoneType error
(`flash_fn` is `None`). Callers that need varlen (continuous batching, padding-free
training) go through `paged_attention_forward` or the explicit
`cu_seq_lens_*` kwarg path, both of which already supply their own
cumulative lengths.
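For reference, the cumulative lengths a varlen kernel expects are just a prefix sum over per-sequence lengths with a leading zero. The sketch below uses plain Python lists and hypothetical helper names. Real callers would hand the kernel an int32 tensor on the target device.

```python
from itertools import accumulate
from typing import List, Sequence


def cu_seqlens_from_lengths(lengths: Sequence[int]) -> List[int]:
    """Prefix sums with a leading 0; sequence i occupies [cu[i], cu[i + 1])."""
    return [0, *accumulate(lengths)]


def dense_cu_seqlens(batch_size: int, seq_len: int) -> List[int]:
    """Special case of an unpadded dense batch: every sequence has the same length."""
    return cu_seqlens_from_lengths([seq_len] * batch_size)


# Packed batch with ragged lengths, as in continuous batching.
ragged = cu_seqlens_from_lengths([5, 2, 7])

# Dense (batch=3, seq_len=4) layout flattened to 12 tokens.
dense = dense_cu_seqlens(3, 4)
```

The result always has `batch_size + 1` entries and its last entry equals the total token count of the flattened batch.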
Companion kernel change: dropping `flash_attn_func` from
kernels-community/metal-flash-sdpa for the same reason (PR #3).
* serve: pin metal-flash-sdpa to the PR revision until it lands on main
The published `main` of `kernels-community/metal-flash-sdpa` predates the MPS
dispatch hardening (contiguity, int32 cast, alias clone, MPS encoder flush)
that this integration depends on. Pin to the open PR's HEAD commit so that the
auto-default actually works end-to-end out of the box.
Drop or bump this constant when the upstream PR merges:
https://huggingface.co/kernels-community/metal-flash-sdpa/discussions/3
* fix