DeepSpeed
456c9ac6 - Stage3: Use new torch grad accumulation hooks API (#6773)

Commit
262 days ago
Stage3: Use new torch grad accumulation hooks API (#6773) * This commit addresses a Deepspeed issue [#6718](https://github.com/microsoft/DeepSpeed/issues/6718) * The existing code has been using the grad_acc node hook to reduce params grads. The constructs such as `param.data = replicated_tensor.data` used in `allgather_params(..)` are compiled into `param.set()` causing the hook assigned to the grad_acc node not being called. * Starting from PyTorch 2.1 there is a new and robust hook API on a param itself: `param.register_post_accumulate_grad_hook(..)` * This commit will make use of the proper API depending on the PyTorch version * It will also disable compile for PyTorch versions < 2.1 --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Author
Parents
Loading