Zero2: avoid graph breaks in torch.compile by using param_idx (#6803)
inside reduce_independent_p_g_buckets_and_remove_grads and in
reduce_ipg_grads which are being executed during the BWD hook in zero2,
the model param is being stored inside params_in_ipg_bucket.
torch.compile has hard time tracing parameters.
By using the param's static index inside the group the same logic can be
maintain with less complexity.
---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Logan Adams <loadams@microsoft.com>