b805e1ab - [functorch] Fix torch.cat batching rule (#86932)

The bug was discovered in https://github.com/pytorch/pytorch/pull/86842.

torch.cat has an edge case where it ignores all tensors of shape [0]. So if any of the BatchedTensors have logical shape [0] but physical shape [B, 0], then we coerce them to shape [0] by slicing them.

Why don't we just ignore those Tensors? We need to propagate requires_grad-ness somehow (e.g. if the BatchedTensor wraps a Tensor of shape [B, 0] that requires grad, then the output must require grad).

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86932
Approved by: https://github.com/Chillee
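
To make the edge case concrete, here is a minimal eager-mode sketch of the behavior the batching rule has to reproduce (illustrative Python only; the actual fix lives in functorch's C++ batching rules, and the variable names below are made up for the example):

```python
import torch

# torch.cat's edge case: a 1-D tensor of shape [0] is ignored entirely,
# even when its shape would otherwise be incompatible with the others.
x = torch.ones(2, 3)
empty = torch.zeros(0)
out = torch.cat([x, empty])
assert out.shape == (2, 3)

# Under vmap, a per-example tensor of logical shape [0] has physical
# shape [B, 0]. Slicing it down to [0] (instead of dropping it) keeps
# autograd metadata flowing: cat's output requires grad whenever any
# input does, even an empty one.
B = 4
physical = torch.zeros(B, 0, requires_grad=True)
sliced = physical[0]                      # shape [0], still requires grad
assert sliced.shape == (0,) and sliced.requires_grad
out = torch.cat([torch.ones(5), sliced])
assert out.requires_grad
```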