pytorch
faf9c47a - Simplify a few diagonal-related functions (#87180)

Commit
3 years ago
Simplify a few diagonal-related functions (#87180) `diag` was unnecessarily implemented as a kernel rather than as a composite function, which made it unnecessarily difficult (explicit backward + all it entails). We also change a few uses of `diag` on 2D tensors for `diagonal()`. The latter returns a view rather than creating a new tensor. We also upgrade its meta implementation to a fully-fledged decomposition I tried implementing the backwards of `diagonal()` via `diag_scatter` (or better `diag_scatter_` to keep the perf) but functionalisation was failing and I was not sure how to fix this, so I moved on. It may be possible to simplify that one as well if @soulitzer or someone knows how to do this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/87180 Approved by: https://github.com/ngimel, https://github.com/albanD, https://github.com/mruberry
Author
Committer
Parents
Loading