Normalize ZeRO-3 DeepCompile grad dtype before reduction (#8038)
Some backward kernels produce gradients in their computation dtype, not
necessarily in the parameter storage dtype. For example, if a backward
path accumulates or promotes math in fp32, a parameter stored as bf16
can still receive an fp32 raw gradient from that backward computation.
In normal PyTorch execution, that raw gradient reaches the leaf-gradient
accumulation step, which stores it according to the tensor's expected
grad dtype. ZeRO-3 DeepCompile intercepts the raw compiled-backward
gradient before that leaf accumulation boundary. The reducer was
assuming the raw gradient dtype was already the expected leaf grad
dtype, so it could select an fp32 communication bucket even when the
ZeRO grad partition storage was bf16.
To address this, this PR changes `dc.reduce_grad`'s behavior to match
PyTorch's leaf-gradient dtype contract. ZeRO-3 registration now records
the expected grad dtype for each parameter, and `reduce_grad` normalizes
raw compiled-backward gradients to that dtype before selecting the
communication bucket.
This follows the documented `grad_dtype` behavior, including preserving
explicit `grad_dtype=None` opt-outs:
https://docs.pytorch.org/docs/main/generated/torch.sparse.semi_structured.SparseSemiStructuredTensorCUSPARSELT.html#torch.sparse.semi_structured.SparseSemiStructuredTensorCUSPARSELT.grad_dtype
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>