transformers
bdca59aa - Enable kernels-community/metal-flash-sdpa on MPS (#45974)

Commit
33 days ago
Enable kernels-community/metal-flash-sdpa on MPS (#45974) * Enable kernels-community/metal-flash-sdpa on MPS via generate_batch Two small fixes so `attn_implementation="kernels-community/metal-flash-sdpa"` works end-to-end on Apple Silicon (`generate` and `generate_batch`): * `modeling_flash_attention_utils._flash_attention_forward`: the "no padding" branch unconditionally called `flash_fn`, which is `None` for varlen-only kernels (the metal kernel only ships `flash_attn_varlen_func`). Synthesize `cu_seqlens` for the dense batched layout and route through `flash_varlen_fn` in that case. `.contiguous()` before reshape is required: the cached K/V (post-transpose) is non-contiguous and the Metal kernel reads garbage off it during decode, producing nonsense tokens. * `continuous_batching/requests.get_device_and_memory_breakdown`: on MPS, `torch.mps.driver_allocated_memory()` returns bytes currently held by the Metal driver (≈0 right after process start), not the total. Use `recommended_max_memory()` for total and `current_allocated_memory()` for the running allocation. Without this, `infer_num_blocks_and_max_batch_tokens` either returns a negative `num_blocks` or refuses to allocate, so `generate_batch` was unusable on MPS regardless of the chosen attention. Bench (gsm8k 100 samples, Qwen2.5-0.5B-Instruct, MPS fp16, generate_batch): impl time(s) tok/s acc sdpa 149.33 158.4 30/100 kernels-community/metal-flash-sdpa 89.78 256.0 32/100 1.66x speedup, accuracy within noise. * serve: auto-select metal-flash-sdpa attention on MPS When `transformers serve` runs on Apple Silicon (`--device auto` or `mps`) with `kernels` installed and no explicit `--attn-implementation` flag, default the attention to `kernels-community/metal-flash-sdpa` instead of plain SDPA. On the 100-sample gsm8k benchmark (Qwen2.5-0.5B-Instruct, MPS fp16, generate_batch) it's a 1.66x throughput improvement (158 -> 256 tok/s) with token-for-token parity for greedy decoding. Users who don't want it can opt out with `--attn-implementation sdpa`. Help text on the `--attn-implementation` flag also now lists the kernels-hub syntax explicitly. * Revert cu_seqlens synthesis in _flash_attention_forward Don't build cu_seqlens on the fly inside the modeling forward — the non-padding `else` branch can stay as a NoneType failure for varlen-only kernels. Callers that need varlen (continuous batching, padding-free training) go through `paged_attention_forward` or the explicit `cu_seq_lens_*` kwarg path, both of which already supply their own cumulative lengths. Companion kernel change: dropping `flash_attn_func` from kernels-community/metal-flash-sdpa for the same reason (PR #3). * serve: pin metal-flash-sdpa to the PR revision until it lands on main The published `main` of `kernels-community/metal-flash-sdpa` predates the MPS dispatch hardening (contiguity, int32 cast, alias clone, MPS encoder flush) that this integration depends on. Pinning to the open PR's HEAD commit so the auto-default actually works end-to-end out of the box. Drop / bump this constant when the upstream PR merges: https://huggingface.co/kernels-community/metal-flash-sdpa/discussions/3 * fix
Author
Parents
Loading