pytorch
71ad1005 - Add prelu into Autocast CPU whitelist (#95366)

Commit
3 years ago
Add prelu into Autocast CPU whitelist (#95366) ### Motivation Add `prelu` to lower precision cast policy on AutocastCPU to fix https://github.com/pytorch/pytorch/issues/95365 : Before: Within the scope of torch.cpu.amp.autocast(dtype=torch.bfloat16) , `prelu` cannot address the scenario of different datatypes of `input` and `weight`, will get a RuntimeError. This scenario is common in autocast, e.g, with `autocast` to `bf16`, if the `op` before `prelu` comes out a `bf16` output, which is the input of `prelu`, and `prelu's` weight is `fp32`, then it will get a RuntimeError. After: Within the scope of torch.cpu.amp.autocast(dtype=torch.bfloat16) , prelu be forced to run with `bf16` data type. Before https://github.com/pytorch/pytorch/pull/91238, when input is `bf16`, weight will be forced to cast to `bf16`. After https://github.com/pytorch/pytorch/pull/91238, this kind of test scenario will raise a RuntimeError. There is no precision loss since the workable one is also casting to `bf16`. And this also alighs with Autocast CUDA whitelist. Pull Request resolved: https://github.com/pytorch/pytorch/pull/95366 Approved by: https://github.com/ngimel, https://github.com/lezcano, https://github.com/leslie-fang-intel
Author
Committer
Parents
Loading