pytorch
845e4b8a - [fix] legacybatching: getPhysicalDims (#93261)

Commit
1 year ago
[fix] legacybatching: getPhysicalDims (#93261) Fixes #92985 Minimum Repro: ```python import torch from torch._vmap_internals import vmap input = torch.randn(2, 2) def fn(x): return x.sum(()) o = vmap(fn)(input) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/93261 Approved by: https://github.com/albanD, https://github.com/Skylion007
Author
Committer
Parents
Loading