jax
1e84cbe9 - [jax2tf] Fix random.split when jax_exable_x64 (#4208)

Commit
5 years ago
[jax2tf] Fix random.split when jax_exable_x64 (#4208) Since we do the threefry with signed integers when converting to TF, we run into the type promotion 'uint32 - int32 = int64', which then results in lax.shift_right_logical(uint32, int64), which fails.
Author
Parents
Loading