pytorch
1cad7446 - Enable select.int when NestedTensor requires grad (#83875)

Commit
2 years ago
Enable select.int when NestedTensor requires grad (#83875) Previously indexing a nested tensor when it requires_grad would raise an error because the backward formula for `select.int` uses `self.sizes()`. This PR fixes that by temporarily registering a _nested_select_backward function which can be removed when we start using the symint approach to register kernels. For now this functionality is needed for creating a POC that nested tensor can be an API to `segment_coo` and `segment_csr` in the torch_scatter repo ``` a = torch.arange(10).reshape(2, 5).float() b = torch.arange(12).reshape(2, 6).float() nt = torch.nested_tensor([a, b], dtype=torch.float).requires_grad_(True) nt[0] # RuntimeError: Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor ``` whereas ``` nt = torch.nested_tensor([a, b], dtype=torch.float).requires_grad_(False) nt[0] ``` would succeed Pull Request resolved: https://github.com/pytorch/pytorch/pull/83875 Approved by: https://github.com/albanD, https://github.com/drisspg
Committer
Parents
Loading