[functorch] relax as_strided batching rule (#83597)
Previously, the as_strided batching rule required the bdim to be at
the front. As noted in a years-old comment in the code, this is not
necessary for correctness; it was a guard against potentially incorrect
behavior, under the assumption that most people would not vmap over
dimensions other than 0.
That assumption has not aged well: we now have batch rules that return
a BatchedTensor whose bdim is something other than 0 (e.g. the
convolution batch rule).
This PR deletes the check for that assumption and adds manual tests
verifying that the as_strided batching rule works when one vmaps over a
dimension other than 0.
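For illustration, here is a minimal sketch (not the PR's actual test code) of what such a manual test checks: vmapping an as_strided call over dim 1 of the input and comparing against a per-sample loop. It uses the torch.func.vmap API; the shapes, strides, and the helper f are made up for the example.

```python
import torch
from torch.func import vmap

B = 3
x = torch.randn(4, B, 5)  # batch dimension is dim 1

def f(t):
    # per-sample: view the first few storage elements as a 2x2 window
    return t.as_strided((2, 2), (1, 1))

# vmap over dim 1 (the case this PR newly allows for as_strided)
out = vmap(f, in_dims=1)(x)

# reference: apply f to each batch slice individually
expected = torch.stack([f(x[:, i, :]) for i in range(B)])
assert torch.allclose(out, expected)
```

With the old front-bdim restriction removed, the batching rule simply prepends the batch size and the bdim's physical stride to the user-supplied sizes and strides.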
Automatic tests don't exist because it's hard to get the
test_vmap_exhaustive test runner to faithfully replicate the strides of
the inputs.
Test Plan:
- wait for tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83597
Approved by: https://github.com/samdow