jax
3c014a4c - Add support for shape polymorphism with lu_pivots_to_permutation.

Commit
1 year ago
Add support for shape polymorphism with lu_pivots_to_permutation. This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used. PiperOrigin-RevId: 662024940
Author
dfm dfm
Committer
Parents
Loading