jax
a5426f8a - Enable neural network tests on ROCm

Commit
113 days ago
Enable neural network tests on ROCm Enable NN tests to run on ROCm by using XLA implementation instead of cuDNN (which is NVIDIA-only) and fixing compute capability skip checks. Changes: - testScaledMatmul: skip compute capability check on ROCm (works on ROCm) - testScaledDotGeneral: skip compute capability check on ROCm (works on ROCm) - testDotProductAttention: add ROCm skip for cuDNN impl (XLA impl still runs) - testDotProductAttentionMask: use XLA instead of cuDNN on ROCm - testDotProductAttentionBiasGradient: use XLA instead of cuDNN on ROCm
Author
Committer
Parents
Loading