Optimize zero3 fetch params using all_reduce (#5420)
* Use all_reduce instead of all_gather to fetch module parameters. This
improves performance by eliminating the concatenation and slicing
overhead, since neither step is required any longer.
* Instead, all tensor views are created prior to the collective
(all_reduce), so upon its completion only the parameter status is
updated.
* The behavior is enabled via a new boolean flag in the config:
"zero_optimization": { "stage3_use_all_reduce_for_fetch_params": true }
* By default the optimization is not enabled.
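
The idea in the points above can be sketched roughly as follows. This is a
single-process illustration only, not the DeepSpeed implementation: the
helper names (simulated_all_reduce_sum, fetch_param) are hypothetical, the
collective is simulated with plain Python, and zero-filling the buffer so
that a SUM reduction reproduces the gathered parameter is an assumption
about how all_reduce can stand in for all_gather.

```python
from array import array


def simulated_all_reduce_sum(buffers):
    """Hypothetical stand-in for an in-place SUM all_reduce across ranks."""
    totals = [sum(values) for values in zip(*buffers)]
    for buf in buffers:
        for i, value in enumerate(totals):
            buf[i] = value  # in-place write, so existing views stay valid


def fetch_param(shards, numel):
    """shards[r] is rank r's equally sized partition; numel is the unpadded length."""
    world_size = len(shards)
    shard_len = len(shards[0])
    buffers, views = [], []
    for rank, shard in enumerate(shards):
        # Zero-filled flat buffer; only this rank's slot holds real data.
        flat = array("d", [0.0]) * (world_size * shard_len)
        flat[rank * shard_len:(rank + 1) * shard_len] = array("d", shard)
        # The parameter view is created BEFORE the collective runs.
        views.append(memoryview(flat)[:numel])
        buffers.append(flat)
    # Summing zero-padded buffers yields the full parameter on every rank,
    # with no concatenation or slicing afterwards.
    simulated_all_reduce_sum(buffers)
    return [view.tolist() for view in views]


# Config fragment enabling the flag named in this change ("stage": 3 selects ZeRO-3).
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "stage3_use_all_reduce_for_fetch_params": True,
    }
}

result = fetch_param([[1.0, 2.0], [3.0, 0.0]], numel=3)
```

Here two simulated ranks each hold half of a padded three-element
parameter; after the simulated collective, the views created beforehand
already expose the full parameter on both ranks.
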
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>