jax
63b94d5c - Enable CUDA Array Interface tests on ROCm platform

Commit
81 days ago
Enable CUDA Array Interface tests on ROCm platform This change enables CUDA Array Interface tests on ROCm GPUs: 1. jaxlib/py_array.cc: Add ROCm platform check alongside CUDA for __cuda_array_interface__ property support. 2. jax/_src/numpy/array_constructors.py: - Add ROCm plugin extension discovery with debug logging 3. tests/array_interoperability_test.py: Change the following tests to use @jtu.run_on_devices("gpu") to run on both CUDA and ROCm: - testCaiToJax - testCudaArrayInterfaceWorks - testCudaArrayInterfaceBfloat16Fails - testCudaArrayInterfaceOnShardedArrayFails 4. tests/array_interoperability_test.py: Skip testCudaArrayInterfaceOnNonCudaFails on ROCm platform by adding "rocm" to @jtu.skip_on_devices decorator.
Author
Committer
Parents
Loading