d4aa807b - Enable bfloat16 for hardtanh_backward_cuda (#91511)

I'm not sure why this was left out in the first place, as all adjacent operations have both Half and BFloat16. Things seem to work as expected, and this enables `relu6` to be used in bfloat16 training. Hardtanh backward is super simple and precision is not relevant.

```
import torch

x_fp32 = torch.tensor([-1, 2, 4, 7], requires_grad=True, dtype=torch.float32, device="cuda")
x_bf16 = torch.tensor([-1, 2, 4, 7], requires_grad=True, dtype=torch.bfloat16, device="cuda")

torch.nn.functional.relu6(x_fp32).sum().backward()
torch.nn.functional.relu6(x_bf16).sum().backward()

assert (x_fp32.grad == x_bf16.grad).all()
```

Previously this would fail with:

```
Traceback (most recent call last):
  File "test_hardtanh_patch.py", line 5, in <module>
    torch.nn.functional.relu6(x_bf16).sum().backward()
  File ".../lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File ".../lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: "hardtanh_backward_cuda" not implemented for 'BFloat16'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91511
Approved by: https://github.com/ngimel
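
For context on why precision is not relevant here: the hardtanh backward pass only gates the incoming gradient, it performs no arithmetic that could lose accuracy in bfloat16. Below is a minimal Python sketch of that rule; the helper `hardtanh_backward_reference` is an illustration written for this note, not the actual CUDA kernel.

```
import torch

def hardtanh_backward_reference(grad_output, x, min_val=-1.0, max_val=1.0):
    # Hypothetical reference implementation: the gradient passes through where
    # min_val < x < max_val and is zeroed elsewhere -- a pure mask, so the
    # result does not depend on the floating-point precision of the dtype.
    mask = (x > min_val) & (x < max_val)
    return grad_output * mask.to(grad_output.dtype)

# relu6 is hardtanh with bounds [0, 6]; compare the sketch against autograd.
x = torch.tensor([-1.0, 2.0, 4.0, 7.0], requires_grad=True)
torch.nn.functional.relu6(x).sum().backward()
expected = hardtanh_backward_reference(torch.ones_like(x), x.detach(), 0.0, 6.0)
assert torch.equal(x.grad, expected)
```

Since the kernel is just this mask-and-pass-through, enabling BFloat16 only requires adding the dtype to the existing dispatch, as the commit does.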