pytorch
39fa0b5d - Add scatter_add to amp promote list (#52133)

Commit
3 years ago
Add scatter_add to amp promote list (#52133) Summary: Fixes https://github.com/pytorch/pytorch/issues/51730 I've added the `scatter_add` and `scatter_add.dimname` to the promote list as well as test cases for the former op. However, it seems that `scatter_add` [doesn't support named tensors yet](https://github.com/pytorch/pytorch/blob/8b0cb5ede3eddb96aa0423b2c73c8560ab44788e/aten/src/ATen/native/NamedTensor.cpp#L356-L358) (thanks t-vi for the pointer): ```python dev = 'cuda' torch.scatter_add(torch.zeros(2, 2, 2, dtype=torch.float16, device=dev, names=('N', 'C', 'L')), 'C', torch.randint(0, 2, (2, 2, 2), device=dev), torch.randn((2, 2, 2), dtype=torch.float32, device=dev)) > RuntimeError: scatter_add: You passed a dimname (string) to this op in place of a dimension index but it does not yet support this behavior. Please pass a dimension index to work around this. ``` which raised this error after adding this test case. I'm thus unsure, if I should also remove `scatter_add.dimname` from the promote list or not. In any case, once named tensors are supported a potential test could be added as: ```python ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float16, device=dev, names=('N', 'C', 'L')), 'C', torch.randint(0, 2, (2, 2, 2), device=dev), torch.randn((2, 2, 2), dtype=torch.float32, device=dev))), ``` CC mcarilli ngimel Pull Request resolved: https://github.com/pytorch/pytorch/pull/52133 Reviewed By: ejguan Differential Revision: D26440392 Pulled By: ngimel fbshipit-source-id: f4ee2d0b9e1f81afb6f94261c497cf2bf79ec115
Author
Parents
Loading