jax
d63afd84 - [Pallas GPU] Enable Pallas `OpsExtraTest` in 64-bit mode

Commit
1 year ago
[Pallas GPU] Enable Pallas `OpsExtraTest` in 64-bit mode This is a follow-up of https://github.com/jax-ml/jax/pull/23747, which enables Pallas `OpsTest` in 64-bit mode. In order to enable Pallas `OpsExtraTest` in 64-bit mode, some of the code in the tests need to be modified. There are three kinds of modifications: 1. Most of the modifications are just changing `jnp.int32` to `intx` and `jnp.float32` to `floatx`, which uses the same approach as the previous PR https://github.com/jax-ml/jax/pull/23747. `intx` and `floatx` are conventions used in Pallas tests to refer to 64-bit types in 64-bit mode and their 32-bit counterparts in 32-bit mode. 2. For the test `test_array_reduce`, the original code uses a simple approach to determine `out_dtype` from `dtype`, which no longer works in 64-bit mode. Therefore, I modified the code to deduct `out_dtype` by executing the operation on a single element first. 3. For the test `test_masked_load_store`, the `idx` variable is expected to be an `int32` array, which is calculated based on `pl.program_id()` and `block_size`. In 64-bit mode, the computation will give out an `int64` array instead. Since `pl.program_id()` always returns an `int32` result, I modified the computation to produce `int32` result. I also modified the `pl.program_id()` docstring to document the behaviour that `pl.program_id()` always returns an `int32` result. PiperOrigin-RevId: 677007613
Author
Parents
Loading