Mixed-precision: per-policy param/buffer dtype cast (preserve fp32 buffers) (#8066)
## Summary
- Add `data_types.param_dtype` and `data_types.buffer_dtype` (both
default `None`), mirroring FSDP `MixedPrecisionPolicy`.
- Replace the blanket `module.half()` / `module.bfloat16()` in
`_configure_distributed_model` with a targeted cast: parameters go to
`param_dtype`; floating buffers keep their loaded dtype unless
`buffer_dtype` is explicitly set.
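
The targeted cast described above could look roughly like the sketch below. It is not the code in `_configure_distributed_model`: `cast_params_and_buffers` and `TinyRotary` are hypothetical names made up for illustration, and the sketch assumes plain `torch.nn` modules with no ZeRO partitioning involved.

```python
import torch
import torch.nn as nn


def cast_params_and_buffers(module, param_dtype, buffer_dtype=None):
    """Hypothetical helper: cast floating params, touch buffers only on request."""
    for param in module.parameters():
        if param.is_floating_point():
            param.data = param.data.to(param_dtype)
    # Buffers keep their loaded dtype unless buffer_dtype is explicitly given.
    if buffer_dtype is not None:
        for buf in module.buffers():
            if buf.is_floating_point():
                buf.data = buf.data.to(buffer_dtype)
    return module


class TinyRotary(nn.Module):
    """Toy module (assumption) with one Linear layer and an fp32 inv_freq buffer."""

    def __init__(self, dim=8, base=10000.0):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)


# Default: params go to bf16, inv_freq stays in fp32.
model = cast_params_and_buffers(TinyRotary(), param_dtype=torch.bfloat16)

# Escape hatch: buffers are downcast too, like the old blanket cast.
legacy = cast_params_and_buffers(
    TinyRotary(), param_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16
)
```

Integer and boolean buffers are skipped by the `is_floating_point()` check, so only floating tensors ever change dtype.
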
## Motivation
The blanket cast downcasts every floating buffer, including the rotary
`inv_freq` buffer that HF/FSDP2 keep in fp32. bf16 carries only 8
significant bits, so at long context lengths the rounded `inv_freq`
shifts the RoPE angles and logits/grads diverge from the FSDP2
reference. Preserving fp32 buffers by default fixes this;
`buffer_dtype` is the escape hatch to reproduce the legacy behavior.
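
To get a feel for the size of the drift, the following standalone sketch emulates bf16 rounding in pure Python and compares RoPE angles computed from fp32 and bf16-rounded `inv_freq` values. The head dimension, base, and position are illustrative assumptions, not values taken from the 8B / 32B runs, and `to_bf16` / `to_fp32` are hypothetical helpers.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round a finite value to bfloat16 (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias on the 16 dropped bits
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Assumed, illustrative RoPE settings.
dim, base, position = 128, 10000.0, 32768

inv_freq_fp32 = [to_fp32(1.0 / base ** (2 * i / dim)) for i in range(dim // 2)]
inv_freq_bf16 = [to_bf16(f) for f in inv_freq_fp32]

# Angle error (radians) at the chosen position, per frequency.
angle_err = [abs(position * (b - f)) for f, b in zip(inv_freq_fp32, inv_freq_bf16)]
max_err = max(angle_err)
print(f"max RoPE angle error at position {position}: {max_err:.3f} rad")
```

The error grows linearly with the position, which is why the problem surfaces mainly on long-context runs.
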
## Behavior
- `param_dtype` unset -> derived from the fp16/bf16 enabled flag (legacy
param behavior).
- `buffer_dtype` unset -> buffers keep their loaded dtype (e.g. fp32
`inv_freq`).
- `buffer_dtype` set -> floating buffers are force-cast to it (parity
with the legacy blanket cast).
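
The fallback rules above can be summarized in a small sketch. `resolve_dtypes` is a hypothetical function, the string dtype names are placeholders rather than the accepted config values, and rejecting a config with both flags enabled is an assumption of this sketch.

```python
from typing import Optional, Tuple


def resolve_dtypes(
    fp16_enabled: bool,
    bf16_enabled: bool,
    param_dtype: Optional[str] = None,
    buffer_dtype: Optional[str] = None,
) -> Tuple[str, Optional[str]]:
    """Hypothetical resolver; a buffer dtype of None means 'keep loaded dtype'."""
    if fp16_enabled and bf16_enabled:
        raise ValueError("fp16 and bf16 cannot both be enabled")
    if param_dtype is None:
        # Legacy param behavior: derive the dtype from the enabled flag.
        if bf16_enabled:
            param_dtype = "bf16"
        elif fp16_enabled:
            param_dtype = "fp16"
        else:
            param_dtype = "fp32"
    # buffer_dtype is passed through untouched: unset stays unset.
    return param_dtype, buffer_dtype


print(resolve_dtypes(fp16_enabled=False, bf16_enabled=True))
print(resolve_dtypes(False, True, buffer_dtype="bf16"))
```

The first call leaves buffers alone; the second requests the legacy downcast for buffers as well.
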
## Test plan
- [ ] `param_dtype=bf16`, `buffer_dtype` unset -> params bf16,
`inv_freq` stays fp32.
- [ ] `buffer_dtype=bf16` -> buffers downcast (legacy parity).
- [ ] bf16/fp16 run with neither key set behaves as before, except that
fp32 buffers are preserved.
- [ ] 8B / 32B ZeRO-3 long-context run -> grad_norm tracks the FSDP2
reference.
---------
Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Signed-off-by: Stas Bekman <stas@stason.org>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Olatunji Ruwase <tjruwase@gmail.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>