jax
7c6c2bb6 - Use dtype=int instead of explicit dtype=jnp.int64 in transposed ragged dot MGPU kernel, so it works for both x64 and x32 configs.

Commit
46 days ago
Use dtype=int instead of explicit dtype=jnp.int64 in transposed ragged dot MGPU kernel, so it works for both x64 and x32 configs. PiperOrigin-RevId: 829096722
Parents
Loading