DeepSpeed
b919284a - Mixed-precision: per-policy param/buffer dtype cast (preserve fp32 buffers) (#8066)

Commit
5 days ago
Mixed-precision: per-policy param/buffer dtype cast (preserve fp32 buffers) (#8066)

## Summary

- Add `data_types.param_dtype` and `data_types.buffer_dtype` (both default `None`), mirroring FSDP `MixedPrecisionPolicy`.
- Replace the blanket `module.half()` / `module.bfloat16()` in `_configure_distributed_model` with a targeted cast: parameters go to `param_dtype`; floating buffers keep their loaded dtype unless `buffer_dtype` is explicitly set.

## Motivation

The blanket cast downcasts every floating buffer, including the rotary `inv_freq` buffer that HF/FSDP2 keep in fp32. On long contexts the bf16 `inv_freq` loses precision, RoPE angles drift, and logits/grads diverge from the FSDP2 reference. Preserving fp32 buffers by default fixes this; `buffer_dtype` is the escape hatch to reproduce the legacy behavior.

## Behavior

- `param_dtype` unset -> derived from the fp16/bf16 enabled flag (legacy param behavior).
- `buffer_dtype` unset -> buffers keep their loaded dtype (e.g. fp32 `inv_freq`).
- `buffer_dtype` set -> buffers force-cast (legacy blanket-cast parity).

## Test plan

- [ ] `param_dtype=bf16`, `buffer_dtype` unset -> params bf16, `inv_freq` stays fp32.
- [ ] `buffer_dtype=bf16` -> buffers downcast (legacy parity).
- [ ] bf16/fp16 run with neither key set behaves as before except fp32 buffers preserved.
- [ ] 8B / 32B ZeRO-3 long-context run -> grad_norm tracks the FSDP2 reference.

Made with [Cursor](https://cursor.com)

---------

Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Signed-off-by: Stas Bekman <stas@stason.org>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Olatunji Ruwase <tjruwase@gmail.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
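
The targeted cast described in this commit message could look roughly like the sketch below. It is not the code from this commit: `cast_params_and_buffers` and `ToyRotary` are hypothetical names invented for illustration, and the sketch assumes plain PyTorch modules with no ZeRO partitioning. It only shows the intended difference between the two paths: parameters always move to `param_dtype`, while floating buffers such as `inv_freq` are left alone unless a `buffer_dtype` is passed.

```python
from typing import Optional

import torch
import torch.nn as nn


def cast_params_and_buffers(
    module: nn.Module,
    param_dtype: torch.dtype,
    buffer_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Hypothetical helper (not DeepSpeed's implementation).

    Floating-point parameters are cast to `param_dtype`. Floating-point
    buffers keep their current dtype unless `buffer_dtype` is given.
    """
    with torch.no_grad():
        for param in module.parameters():
            if param.is_floating_point():
                param.data = param.data.to(param_dtype)
        if buffer_dtype is not None:
            for buf in module.buffers():
                if buf.is_floating_point():
                    buf.data = buf.data.to(buffer_dtype)
    return module


class ToyRotary(nn.Module):
    """Toy module with one linear layer and an fp32 rotary-style buffer."""

    def __init__(self, dim: int = 8, base: float = 10000.0) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))


# Default path: params become bf16, the buffer keeps its loaded dtype.
model = cast_params_and_buffers(ToyRotary(), param_dtype=torch.bfloat16)

# Escape hatch: passing buffer_dtype reproduces the blanket-cast behavior.
legacy = cast_params_and_buffers(
    ToyRotary(), param_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16
)

print(model.proj.weight.dtype, model.inv_freq.dtype)
print(legacy.proj.weight.dtype, legacy.inv_freq.dtype)
```

Here `model.inv_freq` remains `torch.float32` while `legacy.inv_freq` is downcast to `torch.bfloat16`, matching the first two items of the test plan. Integer and boolean buffers are skipped in both paths because only floating-point tensors are touched.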