Disallow passing functions that don't return Tensors to vmap (#40518)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40518
I overlooked this in the initial vmap frontend api PR. Right now we
want to restrict vmap to taking in functions that only return Tensors.
A function that only return tensors can look like one of the following:
```
def fn1(x):
...
return y
def fn2(x):
...
return y, z
```
fn1 returns a Tensor, while fn2 returns a tuple of Tensors. So we add a
check that the output of the function passed to vmap returns either a
single tensor or a tuple of tensors.
NB: These checks allow passing a function that returns a tuple with a
single-element tensor from vmap. That seems OK to me.
Test Plan: - `python test/test_vmap.py -v`
Differential Revision: D22216166
Pulled By: zou3519
fbshipit-source-id: a92215e9c26f6138db6b10ba81ab0c2c2c030929