pytorch
8aca85db - Add diagflat complex support (#47564)

Commit
5 years ago
Add diagflat complex support (#47564) Summary: Adds complex numbers support for `torch.diag` ``` python >>> import torch >>> a = torch.ones(2, dtype=torch.complex128) >>> torch.diagflat(a) tensor([[1.+0.j, 0.+0.j], [0.+0.j, 1.+0.j]], dtype=torch.complex128) >>> b = a.cuda() >>> torch.diagflat(b) tensor([[1.+0.j, 0.+0.j], [0.+0.j, 1.+0.j]], device='cuda:0', dtype=torch.complex128) ``` Note that automatic differentiation isn't implemented: ``` python >>> d = torch.ones(1, dtype=torch.complex128, requires_grad=True) >>> torch.diagflat(d) Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: diag does not support automatic differentiation for outputs with complex dtype. ``` Fixes https://github.com/pytorch/pytorch/issues/47499 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47564 Reviewed By: heitorschueroff Differential Revision: D24844467 Pulled By: anjali411 fbshipit-source-id: 9c8cb795d52880b7dcffab0c059b0f6c2e5ef151
Author
Parents
Loading