jax
9c65f62d - Fix a x64 error in fused_attention_stablehlo

Commit
277 days ago
Fix a x64 error in fused_attention_stablehlo The code assumed that `jnp.arange(x)` will return an `int32` array when `x` is an integer. But the return type depends on the value of the x64 mode. The fix is to specify the desired dtype explicitly. PiperOrigin-RevId: 755376676
Author
Parents
Loading