vmap support for torch.trace (#91679)
Fixes #91404
As expected
```python
import torch
from functorch import vmap
x = torch.randn(32, 3, 3, 3)
y = vmap(torch.trace)(x)
print(y)
```
Now gives the exact same runtime error as eager mode
```
(sourcetorch) ubuntu@ip-172-31-39-26:~/test$ python functorch_test_pos.py
Traceback (most recent call last):
File "functorch_test_pos.py", line 4, in <module>
y = vmap(torch.trace)(x)
File "/home/ubuntu/pytorch/torch/_functorch/vmap.py", line 420, in wrapped
return _flat_vmap(
File "/home/ubuntu/pytorch/torch/_functorch/vmap.py", line 39, in fn
return f(*args, **kwargs)
File "/home/ubuntu/pytorch/torch/_functorch/vmap.py", line 605, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
RuntimeError: trace: expected a matrix, but got tensor with dim 3
```
Equivalent eager code
```python
import torch
x = torch.randn(32, 3, 3, 3)
results = []
for xi in x:
y = torch.trace(xi)
results.append(y)
```
```
(sourcetorch) ubuntu@ip-172-31-39-26:~/test$ python functorch_test_neg.py
Traceback (most recent call last):
File "functorch_test_neg.py", line 5, in <module>
y = torch.trace(xi)
RuntimeError: trace: expected a matrix, but got tensor with dim 3
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91679
Approved by: https://github.com/zou3519