pytorch
5c67d8df - ATen lu_unpack. Required for making `torch.lu_solve` differentiable. (#46913)

ATen lu_unpack. Required for making `torch.lu_solve` differentiable. (#46913)

Summary:
Backward methods for `torch.lu` and `torch.lu_solve` require the `torch.lu_unpack` method. However, while `torch.lu` is a Python wrapper over a native function (so its gradient is implemented via `autograd.Function`), `torch.lu_solve` is a native function and therefore cannot access `torch.lu_unpack`, which is implemented in Python. Hence this PR presents a native (ATen) version of `lu_unpack`. With this function it also becomes possible to update the gradients for `torch.lu` so that backward+JIT is supported (`autograd.Function` does not support JIT).

~~The interface for this method is different from the original `torch.lu_unpack`, so it is decided to keep it hidden.~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46913

Reviewed By: astaff

Differential Revision: D28117714

Pulled By: mruberry

fbshipit-source-id: befd33db12ecc147afacac792418b6f4948fa4a4
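For context, a minimal sketch (not part of the commit) of the Python-facing API this change concerns, using the `torch.lu` / `torch.lu_unpack` / `torch.lu_solve` names of that era; the unpacking step shown here is the operation the PR reimplements natively in ATen:

```python
import torch

# Factor a matrix with partial-pivoted LU: A = P @ L @ U.
A = torch.randn(4, 4, dtype=torch.float64)
LU, pivots = torch.lu(A)  # packed LU factors plus pivot indices

# lu_unpack recovers the explicit P, L, U factors from the packed form.
P, L, U = torch.lu_unpack(LU, pivots)
assert torch.allclose(P @ L @ U, A)

# lu_solve solves A x = b directly from the packed factorization;
# its backward pass is what motivates a native lu_unpack.
b = torch.randn(4, 2, dtype=torch.float64)
x = torch.lu_solve(b, LU, pivots)
assert torch.allclose(A @ x, b)
```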