pytorch
5842e5c1 - vmap support for torch.tril and torch.triu (#94287)

Commit
2 years ago
vmap support for torch.tril and torch.triu (#94287) Summary: Add vmap support for torch.tril and torch.triu. Fix: #91403 Test Plan: GitHub pipeline Differential Revision: D43016624 ### Expected behavior Same as using for-loop: ```python import torch x = torch.randn(32, 3) results = [] for xi in x: y = torch.triu(xi) results.append(y) """ triu: input tensor must have at least 2 dimensions --------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-7-d726203efb0e> in <module> 4 results = [] 5 for xi in x: ----> 6 y = torch.triu(xi) 7 results.append(y) RuntimeError: triu: input tensor must have at least 2 dimensions """ ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/94287 Approved by: https://github.com/Skylion007, https://github.com/zou3519
Author
Committer
Parents
Loading