vmap: Fix bug with x * 0.1 (#43218)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43218
Previously, `vmap(lambda x: x * 0.1)(torch.ones(3))` would return a
float64 tensor(!!). This is due to a subtle bug in the batching rule,
which receives:
- a batched tensor for `x`
- a scalar (0-dim) tensor: `tensor(0.1, dtype=torch.float64)`
The batching rule expands the scalar tensor to the same size as `x` and
then multiplies the two tensors, promoting the output to float64.
However, this is incorrect: the scalar tensor should participate in type
promotion as a scalar. Multiplying a FloatTensor by a double scalar
tensor does not usually promote the result to double.
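The intended promotion semantics can be checked directly in eager mode (a small sketch; the `expand` call here just mimics what the buggy batching rule did and is not part of the vmap API):

```python
import torch

x = torch.ones(3)                           # float32
s = torch.tensor(0.1, dtype=torch.float64)  # 0-dim ("scalar") double tensor

# A 0-dim tensor participates in type promotion like a scalar:
# multiplying a float32 tensor by a double scalar tensor stays float32.
print((x * s).dtype)            # torch.float32

# But once the scalar tensor is expanded to x's size, it becomes a
# dimensioned tensor and its dtype wins the promotion -- reproducing
# the bug described above.
print((x * s.expand(3)).dtype)  # torch.float64
```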
Another bug this PR fixes:
`vmap(torch.mul)(torch.ones(3), torch.ones(3, dtype=torch.float64))`
Multiplying a scalar float tensor by a scalar double tensor produces a
float tensor, but before this PR the above produced a float64 tensor
because the batching rule mistakenly type-promoted the operands.
Test Plan:
- new test: `pytest test/test_vmap.py -v`
- I refactored some tests a bit.
Reviewed By: cpuhrsch
Differential Revision: D23195418
Pulled By: zou3519
fbshipit-source-id: 33b7da841e55b47352405839f1f9445c4e0bc721