pytorch
9dd18247 - Fix dispatch keys for eigh, lu_solve (#60945)

Commit
3 years ago
Fix dispatch keys for eigh, lu_solve (#60945) Summary: I added a test to `test_ops.py` that verifies that the op can run correctly from different cuda devices. This test revealed that `linalg_eigh`, `linalg_eigvalsh`, `linalg_matrix_rank`, `linalg_pinv` were failing. `matrix_rank` and `pinv` are calling `eigh` internally. `linalg_eigh` and `lu_solve` internally use dispatch stubs, so they should be registered with `CPU, CUDA` dispatch keys. The generated code includes device guards in this case and the problem is not present. Implemented a better out variant for `eigvalsh` and registered it with `CPU, CUDA` dispatch keys. ~I added a device guard to `linalg_eigh_kernel` as a fix for `eigvalsh` function. This function needs to be registered as CompositeImplicitAutograd, because it calls `at::linalg_eigh` if `at::GradMode::is_enabled()`.~ Fixes https://github.com/pytorch/pytorch/issues/60892. Pull Request resolved: https://github.com/pytorch/pytorch/pull/60945 Reviewed By: mruberry Differential Revision: D29589580 Pulled By: ngimel fbshipit-source-id: 5851605958bdfc3a1a1768263934619449957168
Author
Parents
Loading