jax
bc3306c8 - [shape_poly] Improve threefry with symbolic shapes

Commit
1 year ago
[shape_poly] Improve threefry with symbolic shapes Previously, we could only handle threefry for the case when it was possible to tell statically that the size of the `count` array is even or odd. This meant that often we had to add a constraint that one of the dimensions is even. Here we rewrite the handling of threefry to not require a Python-level conditional about evenness of the size of the count array. We use a couple of `lax.dynamic_slice` rather than a `lax.split`. We also generalize the tests to cases where the size if fully symbolic, and we cannot tell statically that it is even.
Author
Committer
Parents
Loading