Add linalg.lu

This PR modifies `lu_unpack` by:
- Using less memory when unpacking `L` and `U`
- Fusing the subtraction by `-1` with `unpack_pivots_stub`
- Defining tensors of the correct types to avoid copies
- Porting `lu_unpack` to a structured kernel so that its `_out` version does not incur extra copies

We then implement `linalg.lu` as a structured kernel, computing its derivative manually, since composing the derivatives of `torch.lu_factor` and `torch.lu_unpack` would be less efficient. Both the new function and `lu_unpack` come with the full feature set: forward and backward AD, decent docs, correctness tests, an OpInfo, complex support, support for meta tensors, and support for vmap and vmap over the gradients. I really hope we don't continue adding more features.

This PR also stops saving some tensors that were previously saved unnecessarily for the backward pass in `lu_factor_ex_backward` and `lu_backward`, and makes some other general improvements here and there to the forward and backward AD formulae of related functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67833
Approved by: https://github.com/IvanYashchuk, https://github.com/nikitaved, https://github.com/mruberry
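To make the factor/unpack split concrete, here is a minimal NumPy sketch of the two operations the commit touches: a LAPACK-style `getrf` factorization that packs `L` and `U` into one matrix with 1-indexed pivots, and an unpack step that recovers `P`, `L`, `U` so that `A == P @ L @ U`. This is an illustrative sketch only, not PyTorch's implementation; the function names mirror `torch.lu_factor` / `torch.lu_unpack` but the code below is hypothetical.

```python
import numpy as np

def lu_factor(a):
    """Compact LU factorization with partial pivoting (getrf-style).

    Returns the packed LU matrix (L strictly below the diagonal, U on
    and above it) and 1-indexed pivots, mirroring the layout that
    torch.lu_factor produces. Sketch only, not PyTorch's kernel.
    """
    lu = a.astype(float).copy()
    n = lu.shape[0]
    pivots = np.zeros(n, dtype=int)
    for k in range(n):
        p = k + np.argmax(np.abs(lu[k:, k]))  # partial-pivot row
        pivots[k] = p + 1                     # LAPACK pivots are 1-indexed
        if p != k:
            lu[[k, p]] = lu[[p, k]]           # swap rows k and p
        lu[k + 1:, k] /= lu[k, k]             # store L's multipliers in place
        lu[k + 1:, k + 1:] -= np.outer(lu[k + 1:, k], lu[k, k + 1:])
    return lu, pivots

def lu_unpack(lu, pivots):
    """Unpack (LU, pivots) into P, L, U with a == P @ L @ U."""
    n = lu.shape[0]
    L = np.tril(lu, -1) + np.eye(n)  # unit lower triangular
    U = np.triu(lu)
    perm = np.arange(n)
    for k, p in enumerate(pivots):   # replay the row swaps in order
        perm[k], perm[p - 1] = perm[p - 1], perm[k]
    P = np.eye(n)[:, perm]           # permutation matrix undoing the swaps
    return P, L, U
```

Note the `- 1` on the pivots inside the unpack loop: converting LAPACK's 1-indexed pivots to 0-indexed ones is exactly the subtraction this PR fuses into `unpack_pivots_stub` rather than materializing a shifted pivot tensor first.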