[functorch] Fix index.Tensor, index_put batching rules (pytorch/functorch#862)
Fixes https://github.com/pytorch/functorch/issues/859
Start reading at `NOTE: [advanced indexing (index.Tensor) batch rule]`
in the code for details. This PR rewrites the index.Tensor and index_put
batching rules.
The TL;DR is:
- advanced indexing has different behavior depending on if the "advanced
indices are adjacent":
https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
- we have to take this into account in our batching rules, because
index.Tensor and index_put handle these internally.
Test Plan
- I added new test cases for getitem and aten.ops.index_put via OpInfo
testing.
Future
- primtorch should have a sane decomposition that we can use
- We haven't fixed the index_put_ batching rule yet. TODO later...
- Upstream our test cases (see next section) into pytorch/pytorch