jax
a981e1c4 - Start adding primitive registration helper functions to lax.linalg.

Commit
317 days ago
Start adding primitive registration helper functions to lax.linalg. As part of my efforts to simplify the primitive implementations in lax.linalg, I've found that all of the primitives share some common logic when it comes to impls, abstract_evals, and batching. This change adds some helper functions and starts the process of abstracting the primitive definitions to simplify and reduce duplication. I will continue with the rest of the primitives in lax.linalg, but I didn't want to overload the first diff. PiperOrigin-RevId: 729471970
Author
dfm dfm
Parents
Loading