pytorch
c362138f - Disallow passing functions that don't return Tensors to vmap (#40518)

Commit
4 years ago
Disallow passing functions that don't return Tensors to vmap (#40518) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40518 I overlooked this in the initial vmap frontend api PR. Right now we want to restrict vmap to taking in functions that only return Tensors. A function that only return tensors can look like one of the following: ``` def fn1(x): ... return y def fn2(x): ... return y, z ``` fn1 returns a Tensor, while fn2 returns a tuple of Tensors. So we add a check that the output of the function passed to vmap returns either a single tensor or a tuple of tensors. NB: These checks allow passing a function that returns a tuple with a single-element tensor from vmap. That seems OK to me. Test Plan: - `python test/test_vmap.py -v` Differential Revision: D22216166 Pulled By: zou3519 fbshipit-source-id: a92215e9c26f6138db6b10ba81ab0c2c2c030929
Author
Parents
Loading