onnxruntime
5996a1ec - feat(orttraining): add CPU fallback for FusedAdam optimizer (#28233)

Commit
1 day ago
## Summary

- `FusedAdam.__init__` now detects `torch.cuda.is_available()` and falls back to a standard PyTorch optimizer on CPU instead of crashing.
- A one-time `UserWarning` informs the user that the fused CUDA kernel is unavailable and a CPU implementation is in use.
- `step()` and `zero_grad()` delegate to the fallback when present; the CUDA path is unchanged.

## Motivation

On CPU-only PyTorch builds, `FusedAdam` raises immediately in `__init__` because it unconditionally:

1. Allocates `torch.cuda.IntTensor([0])` as an overflow buffer.
2. Imports the CUDA-only C++ extension `onnxruntime.training.ortmodule.torch_cpp_extensions.fused_ops`.

This makes it impossible to use `FusedAdam` in CPU-only test/dev environments or to write code that transparently works on either device.

The maintainer (@baijumeswani) confirmed in the issue that a CPU fallback with a warning is the desired fix.

Fixes #17403

## Changes

`orttraining/orttraining/python/training/optim/fused_adam.py`:

- Wrap the two CUDA-specific allocations in `if torch.cuda.is_available()`.
- On CPU, build `self._cpu_fallback_optimizer` based on `adam_w_mode`:
  - `ADAM_L2_REGULARIZATION` → `torch.optim.Adam` (weight_decay applied as L2 regularization)
  - `ADAMW_TORCH` → `torch.optim.AdamW`
  - `ADAMW_TRANSFORMERS` → `transformers.AdamW` (with `torch.optim.AdamW` fallback when `transformers` is not installed, plus a second warning)
- Emit a single `UserWarning` per instance.
- `step()` and `zero_grad()` early-return through the fallback when set.
- Update the docstring to drop the "GPU-only" claim.

`orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py` (new):

- Patches `torch.cuda.is_available()` to return `False` so tests run deterministically on any host.
- Asserts instantiation succeeds and emits a `UserWarning`.
- Asserts a single `step()` produces parameter updates equivalent to `torch.optim.AdamW`.
- Asserts `AdamWMode.ADAM_L2_REGULARIZATION` instantiates and steps without raising.

## Test Plan

- `python -m pytest orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py -v`: 3 passed.
- `lintrunner -a` on both files: clean, no changes applied.
- The CUDA code path is byte-for-byte unchanged in behavior; only wrapped in a conditional. No behavioral change for existing GPU users.
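
## Illustrative sketch

The fallback selection described under Changes can be sketched roughly as follows. This is not the code in `fused_adam.py`. The names `FusedAdamSketch`, `pick_cpu_fallback`, `cuda_available`, and `transformers_installed` are hypothetical, the enum is a stand-in for the real `AdamWMode`, and the warning texts are made up. So that the sketch runs without `torch` installed, it returns the optimizer class name as a string rather than constructing an optimizer.

```python
import warnings
from enum import IntEnum


class AdamWMode(IntEnum):
    # Stand-in for the real enum; the numeric values are illustrative only.
    ADAM_L2_REGULARIZATION = 0
    ADAMW_TRANSFORMERS = 1
    ADAMW_TORCH = 2


def pick_cpu_fallback(mode, transformers_installed):
    """Return the name of the optimizer class a CPU fallback would use."""
    if mode == AdamWMode.ADAM_L2_REGULARIZATION:
        return "torch.optim.Adam"
    if mode == AdamWMode.ADAMW_TORCH:
        return "torch.optim.AdamW"
    if mode == AdamWMode.ADAMW_TRANSFORMERS:
        if transformers_installed:
            return "transformers.AdamW"
        # Second warning: the preferred implementation is missing.
        warnings.warn(
            "transformers is not installed; using torch.optim.AdamW instead.",
            UserWarning,
            stacklevel=3,
        )
        return "torch.optim.AdamW"
    raise ValueError(f"Unsupported mode: {mode!r}")


class FusedAdamSketch:
    """Hypothetical wrapper showing the warn-once-then-delegate pattern."""

    def __init__(self, mode, cuda_available, transformers_installed=False):
        # In the real optimizer this flag would come from
        # torch.cuda.is_available(); here it is passed in (an assumption).
        self.fallback_name = None
        if not cuda_available:
            warnings.warn(
                "Fused CUDA kernel unavailable; using a CPU implementation.",
                UserWarning,
                stacklevel=2,
            )
            self.fallback_name = pick_cpu_fallback(mode, transformers_installed)

    def step(self):
        # Early-return through the fallback when one was selected.
        if self.fallback_name is not None:
            return f"delegated to {self.fallback_name}"
        return "fused CUDA path"
```

Passing the availability flags in as arguments is a choice made for the sketch: it lets both branches be exercised on any host, which is the same goal the new test file reaches by patching `torch.cuda.is_available()`.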