Change mixed precision API
Change the API so that if a precision type is not explicitly specified in the `MixedPrecision` config, we do not cast it at all.
Previously, we would default to fp16 reduced precision, but this does not account for the case where the user might want, for example, reduced precision only for gradient communication. Trying to do this via:
```
MixedPrecision(reduce_dtype=torch.float16, param_dtype=torch.float32, buffer_dtype=torch.float32)
```
does not work for all use cases, because the code will still attempt to cast params to fp32, while the user's model may assume double/fp64 somewhere.
Now, specifying
```
MixedPrecision(reduce_dtype=torch.float16)
```
affects only gradient communication precision, and does not touch casting of params / buffers.
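The "unspecified means no cast" semantics can be sketched with a minimal stand-in for the config (a hypothetical `MixedPrecisionConfig` dataclass and a `maybe_cast` helper; dtypes are represented as strings here purely for illustration, with `None` meaning "leave untouched"):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MixedPrecisionConfig:
    # None means "do not cast at all": the tensor keeps its original dtype.
    param_dtype: Optional[str] = None
    reduce_dtype: Optional[str] = None
    buffer_dtype: Optional[str] = None

def maybe_cast(tensor_dtype: str, configured: Optional[str]) -> str:
    # Only cast when a dtype was explicitly specified in the config.
    return configured if configured is not None else tensor_dtype

# Specifying only reduce_dtype leaves params / buffers untouched,
# even if the model uses fp64 somewhere:
cfg = MixedPrecisionConfig(reduce_dtype="float16")
assert maybe_cast("float64", cfg.param_dtype) == "float64"
assert maybe_cast("float32", cfg.buffer_dtype) == "float32"
```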
Note that if the user specifies reduced precision for only parameters, gradients will already be of that reduced type in `_post_backward_hook` and are therefore communicated in that precision. The precision in which grads are communicated is thus resolved in priority order:
reduce_dtype, if specified -> param_dtype, if specified -> full precision param type.
We take additional care to ensure grads are cast back to the full precision param type for the optimizer step: either when the parameter itself was in reduced precision, or when the parameter was in full precision but reduced gradient precision was configured for communication.
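The resolution order and the cast-back condition above can be sketched as plain-Python helpers (hypothetical function names; dtypes as strings for illustration, not the actual FSDP implementation):

```python
from typing import Optional

def grad_comm_dtype(reduce_dtype: Optional[str],
                    param_dtype: Optional[str],
                    full_precision_dtype: str) -> str:
    # Priority: reduce_dtype, if specified -> param_dtype, if specified
    # -> full precision param type.
    if reduce_dtype is not None:
        return reduce_dtype
    if param_dtype is not None:
        return param_dtype
    return full_precision_dtype

def needs_cast_back(reduce_dtype: Optional[str],
                    param_dtype: Optional[str],
                    full_precision_dtype: str) -> bool:
    # Grads must be restored to the full precision param type for the
    # optimizer step whenever they were communicated in reduced precision,
    # whether via param casting or via comm-only reduction.
    comm = grad_comm_dtype(reduce_dtype, param_dtype, full_precision_dtype)
    return comm != full_precision_dtype
```

For example, `grad_comm_dtype(None, "float16", "float32")` resolves to `"float16"`, so `needs_cast_back` is true for that configuration.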
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76423
Approved by: https://github.com/zhaojuanmao