[vmap] fix reduction boxed batching rules (#91109)
Fixes https://github.com/pytorch/pytorch/issues/91041
There's a bug in our boxed reduction batching rules for a very specific
case: vmap over a Tensor of shape [1] for an operation where the
output rank is supposed to be less than the input rank, e.g.
```
import torch
from functorch import vmap

x = torch.tensor([10.], device="cpu")  # repros on any device
y = vmap(lambda x: x.sum(0))(x)        # each slice is a scalar; y has shape [1]
```
The boxed reduction batching rule handles three types of "reduction"
operations (illustrated in the sketch after this list):
- reduction operations with an optional keepdim argument, which
specifies if the output should have the same or smaller rank than the
input
- reduction operations without a keepdim arg that morally have keepdim=True (like cumsum --
which never actually modifies the rank of the tensor but is still a
"reduction" since it sums a bunch of things together)
- reduction operations without a keepdim arg that morally have
keepdim=False (currently just torch.count_nonzero).
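For concreteness, here is a sketch of the three categories on an unbatched input (standard PyTorch semantics):
```
import torch

x = torch.randn(3, 4)
x.sum(0).shape                   # [4]: has a keepdim arg, default False
x.sum(0, keepdim=True).shape     # [1, 4]: keepdim=True preserves rank
x.cumsum(0).shape                # [3, 4]: no keepdim arg; rank always preserved
torch.count_nonzero(x, 0).shape  # [4]: no keepdim arg; rank always reduced
```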
Furthermore, PyTorch has special handling for scalar tensors (e.g.
tensors of shape []). It is valid to do
`torch.sum(torch.tensor(10.), dim=0)`.
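That is, reducing over dim 0 of a rank-0 tensor is accepted and is a no-op on the shape:
```
import torch

s = torch.tensor(10.)      # scalar tensor, shape []
torch.sum(s, dim=0).shape  # torch.Size([]): still a scalar, no error
```
This matters here because vmapping over a Tensor of shape [1] presents each logical example to the batching rule as exactly such a scalar.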
This PR updates the `boxed_reduction_batch_rule` to handle the
interaction between the three kinds of reduction and the scalar tensor
cases correctly. Concretely, it:
- introduces additional template parameters on `boxed_reduction_batch_rule`
that encode which type of "keepdim" reduction is being handled.
- splits the old REDUCTION_BOXED macro (which was a good default) into
REDUCTION_NO_KEEPDIM_ARG and REDUCTION_WITH_KEEPDIM_ARG (which are also
opinionated defaults) and uses them.
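With the fix, all three categories compose with vmap even at batch size 1. A Python sketch of the expected behavior (assuming functorch's vmap; the change itself lives in the C++ batching rules):
```
import torch
from functorch import vmap

x = torch.randn(1, 3)  # batch size 1; each logical example has shape [3]
vmap(lambda t: t.sum(0))(x).shape                   # [1]: rank-reducing
vmap(lambda t: t.cumsum(0))(x).shape                # [1, 3]: rank-preserving
vmap(lambda t: torch.count_nonzero(t, 0))(x).shape  # [1]: rank-reducing, no keepdim arg
```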
Test Plan:
- Given an input of shape [], our vmap OpInfo test suite only produces
a Tensor of shape [B] with B = 2. At first glance this doesn't look
sufficient to test this case (vmap over a Tensor of shape [1]), but the
claim is that it is, because the boxed_reduction_batch_rule is agnostic
to the size of the dimension being vmapped over. Previously it was not
agnostic, due to the semantics of `squeeze` (see the snippet after this
list); this PR adds internal asserts to guarantee the agnosticism.
- there is a lightweight test for vmap over a Tensor of shape [1] for
torch.sum as a sanity check.
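For reference, the `squeeze` issue from the first bullet: `squeeze(dim)` removes the dimension only when its size is 1, so logic built on it silently depends on the batch size:
```
import torch

torch.randn(2, 1).squeeze(0).shape  # torch.Size([2, 1]): dim 0 has size 2, no-op
torch.randn(1, 1).squeeze(0).shape  # torch.Size([1]): dim 0 removed -- batch-size-dependent
```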
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91109
Approved by: https://github.com/samdow