jax
74e8d739 - Move ROCm GPU detection to hardware_utils and add multi-GPU shm warning

Commit
12 days ago
Move ROCm GPU detection to hardware_utils and add multi-GPU shm warning

- Add count_amd_gpus() and get_shm_size() to jax/_src/hardware_utils.py
  - count_amd_gpus() detects AMD GPUs via KFD sysfs on Linux/WSL2
  - Uses functional pipeline with generators, filter, and islice
  - Supports early exit with stop_at parameter for performance
  - get_shm_size() checks /dev/shm size for multi-GPU setups
- Update jax_plugins/rocm/__init__.py to use hardware_utils
  - Fix package names: add jax_rocm7_plugin, change jaxlib.cuda to jaxlib.rocm
  - Add GPU counting before plugin initialization
  - Warn about low /dev/shm size (<=64MB) with multiple GPUs
  - Changed from error to warning when no GPUs detected
- Add comprehensive unit tests in tests/rocm_plugin_init_test.py
  - Mock filesystem operations to test without real hardware
  - Test GPU counting, stop_at limiting, file validation
  - Test /dev/shm size checking and error handling
  - 9 tests total, all passing
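The detection approach described in the message could look roughly like the sketch below. It is not the code from this commit. The sysfs path, the use of a nonzero simd_count to tell GPU nodes from CPU nodes, the use of os.statvfs, and the names _simd_count, should_warn_about_shm, and nodes_dir are all assumptions made for illustration. Only the function names count_amd_gpus and get_shm_size, the stop_at parameter, and the 64MB threshold come from the message.

```python
import itertools
import os
from pathlib import Path

# Assumed location of the KFD topology tree; the commit message does not give the path.
KFD_NODES = Path("/sys/class/kfd/kfd/topology/nodes")
SHM_WARN_BYTES = 64 * 1024 * 1024  # the "<=64MB" threshold from the commit message


def _simd_count(properties_file):
    """Hypothetical helper: read simd_count from a KFD properties file (0 on any failure)."""
    try:
        for line in properties_file.read_text().splitlines():
            key, _, value = line.partition(" ")
            if key == "simd_count":
                return int(value)
    except (OSError, ValueError):
        pass
    return 0


def count_amd_gpus(stop_at=None, nodes_dir=KFD_NODES):
    """Count KFD nodes that look like GPUs, stopping early once stop_at are found."""
    if not nodes_dir.is_dir():
        return 0
    # Lazy pipeline: nothing is read until islice pulls items through the filter.
    property_files = (node / "properties" for node in nodes_dir.iterdir())
    gpu_nodes = filter(lambda p: _simd_count(p) > 0, property_files)
    return sum(1 for _ in itertools.islice(gpu_nodes, stop_at))


def get_shm_size(path="/dev/shm"):
    """Return the total size of the filesystem at path in bytes, or None if unavailable."""
    try:
        stats = os.statvfs(path)
    except OSError:
        return None
    return stats.f_frsize * stats.f_blocks


def should_warn_about_shm(gpu_count, shm_bytes, threshold=SHM_WARN_BYTES):
    """Hypothetical helper: warn only when several GPUs share a small /dev/shm."""
    return gpu_count > 1 and shm_bytes is not None and shm_bytes <= threshold
```

Because the pipeline is lazy, a call such as count_amd_gpus(stop_at=2) stops reading properties files as soon as two GPU nodes have been seen, which is enough to decide whether the multi-GPU /dev/shm check is needed. Taking nodes_dir as a parameter is one way to test the logic against a temporary directory instead of real hardware; the commit itself mocks filesystem operations.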