jax
8adb773c - Add `cum{logsumexp, min, max, prod, sum}` to JAX roofline.

Commit
237 days ago
Add `cum{logsumexp, min, max, prod, sum}` to JAX roofline. These rules are similar to a `unary` op except that they only compute flops for the given axis. `cumlogsumexp` takes twice as many ops given the complexity of that function. PiperOrigin-RevId: 772608005
Author
Committer
Parents
Loading