Fix a x64 error in fused_attention_stablehlo
The code assumed that `jnp.arange(x)` will return an `int32` array when
`x` is an integer. But the return type depends on the value of the x64 mode.
The fix is to specify the desired dtype explicitly.
PiperOrigin-RevId: 755376676