DeepSpeed
a5b04f02 - Remove get_accelerator().amp() API and use torch.amp directly

Commit
23 days ago
Remove get_accelerator().amp() API and use torch.amp directly The amp() method on each accelerator returned a device-specific torch.<device>.amp module, but since PyTorch 2.4 the unified torch.amp API (torch.amp.custom_fwd, torch.amp.custom_bwd, torch.amp.autocast) accepts a device_type argument and works across all backends. The previous commit already migrated the two call sites; this commit removes the now-unused amp() abstract method and all 8 accelerator implementations, plus simplifies the custom_fwd/custom_bwd setup in zero/linear.py by dropping the pre-2.4 fallback path. Signed-off-by: Ma, Guokai <guokai.ma@intel.com>
Author
Parents
Loading