[functorch] Fix torch.cat batching rule (#86932)
The bug was discovered in https://github.com/pytorch/pytorch/pull/86842.
torch.cat has an edge case where it ignores all tensors of shape [0]. So
if any of the BatchedTensors has logical shape [0] but physical shape
[B, 0], we coerce it to shape [0] by slicing it.
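
A minimal sketch of the edge case and of one way the coercion could look. The `reshape(0)` here is only an illustrative stand-in for the slicing done in the batching rule, and `x`, `empty`, `physical`, and `B` are hypothetical names:

```python
import torch

# torch.cat's edge case: 1-D tensors of shape [0] are skipped, even though
# they would otherwise fail the usual dim/shape checks.
x = torch.randn(2, 3)
empty = torch.randn(0)
out = torch.cat([x, empty], dim=1)  # works; `empty` is effectively ignored
print(out.shape)                    # torch.Size([2, 3])

# Under vmap, a BatchedTensor with logical shape [0] has physical shape
# [B, 0]. Coercing the physical tensor back to shape [0] with a
# differentiable view keeps it in the cat call and in the autograd graph
# (illustrative only, not the actual batching-rule code):
B = 4
physical = torch.randn(B, 0, requires_grad=True)
coerced = physical.reshape(0)       # shape [0], still tied to `physical`
print(coerced.shape, coerced.requires_grad)  # torch.Size([0]) True
```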
Why don't we just ignore those Tensors? We need to propagate
requires_grad-ness somehow (e.g. if the BatchedTensor wraps a Tensor of
shape [B, 0] that requires grad, then the output must require grad).
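
A sketch of the requires_grad propagation this is meant to preserve; the function `f` and the tensor names are hypothetical, and the printed values reflect the behavior expected after this fix:

```python
import torch
from functorch import vmap  # torch.func.vmap on newer PyTorch versions

def f(x, empty):
    # `empty` has logical shape [0] inside vmap (physical shape [B, 0]).
    # If the batching rule simply dropped it, the output would not
    # require grad.
    return torch.cat([x, empty])

B = 3
x = torch.randn(B, 5)
empty = torch.randn(B, 0, requires_grad=True)

out = vmap(f)(x, empty)
print(out.shape)           # torch.Size([3, 5])
print(out.requires_grad)   # expected True: grad-ness comes from `empty`
```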
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86932
Approved by: https://github.com/Chillee