DeepSpeed
00ea0c46 - Zero2: avoid graph breaks in torch.compile by using param_idx (#6803)

Commit
364 days ago
Zero2: avoid graph breaks in torch.compile by using param_idx (#6803) inside reduce_independent_p_g_buckets_and_remove_grads and in reduce_ipg_grads which are being executed during the BWD hook in zero2, the model param is being stored inside params_in_ipg_bucket. torch.compile has hard time tracing parameters. By using the param's static index inside the group the same logic can be maintain with less complexity. --------- Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Logan Adams <loadams@microsoft.com>
Author
Parents
Loading