10bf20da - Change mixed precision API

Change mixed precision API

Change the API such that if a precision type is not explicitly specified in the `MixedPrecision` config, we do not cast it at all. Previously, we would default to fp16 reduced precision, but this does not account for the case where the user might want to, for example, use only reduced gradient communication precision. Trying to do this via:

```
MixedPrecision(reduce_dtype=torch.float16, param_dtype=torch.float32, buffer_dtype=torch.float32)
```

does not work for all use cases because the code will still attempt to cast params to fp32, but the user's model may assume double type/fp64 somewhere. Now, specifying

```
MixedPrecision(reduce_dtype=torch.float16)
```

affects only gradient communication precision and does not touch casting of params or buffers.

Note that if the user specifies reduced precision for only parameters, gradients will be of this reduced type in `_post_backward_hook` and are therefore communicated in that precision. The priority of the precision in which grads are communicated is therefore: `reduce_dtype`, if specified -> `param_dtype`, if specified -> full precision param type.

We take additional care to make sure grads are cast back to the full precision param type for the optimizer step: either if the parameter was in reduced precision, or if the parameter was not in reduced precision but reduced gradient precision for communication was configured.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76423
Approved by: https://github.com/zhaojuanmao
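
As a minimal sketch of the new behavior (not part of the commit; it assumes a distributed process group has already been initialized and uses the public FSDP wrapper and `MixedPrecision` config), passing only `reduce_dtype` leaves params and buffers in their original dtype and lowers precision only for gradient communication:

```
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision


def wrap_with_reduced_grad_comm(model: nn.Module) -> FSDP:
    # Only reduce_dtype is specified, so params and buffers keep their
    # original dtype (fp32, fp64, ...); only the gradient reduction
    # (reduce-scatter / all-reduce) is performed in fp16.
    mp = MixedPrecision(reduce_dtype=torch.float16)
    return FSDP(model, mixed_precision=mp)


# Usage sketch (assumes torch.distributed.init_process_group("nccl") has
# been called and a hypothetical MyModel / data batch are defined):
#   model = wrap_with_reduced_grad_comm(MyModel().cuda())
#   loss = model(batch).sum()
#   loss.backward()        # grads communicated in fp16
#   optimizer.step()       # grads cast back to the full-precision param type
```

Under the previous default, the same config would also have cast params and buffers to fp16; with this change, omitting `param_dtype` and `buffer_dtype` means they are left untouched.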