pytorch
bc03aa60 - Store `autocast_gpu_dtype` in `custom_fwd` and `custom_bwd` for BFloat16 autocast (#88029)

As per #87979, `custom_bwd` forcefully uses `torch.float16` for `torch.autograd.Function.backward` regardless of the `dtype` used in the forward.

Changes:
- store the `dtype` in `args[0]`
- update tests to confirm the dtype of intermediate result tensors that are outputs of autocast-compatible `torch` functions

cc @ptrblck @ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88029
Approved by: https://github.com/ngimel