ZeRO-3: improved parameter all-gather operation (#1188)
* remove norm(), avoid memcpy after allgather
1) Removing the norm computation in debug printing
2) Changing _all_gather to a sync op in fetch_sub_module
Reason: the async version is not actually asynchronous, because each
all_gather calls torch.cuda.synchronize() to guarantee the previous
communication op has completed
3) Adding a new function _allgather_params_split_launch
The existing _allgather_params does an explicit memcpy after the
all-gather op. By gathering directly into the destination buffer we
can avoid that explicit memory copy on the Python side and improve
performance (see the sketch after the known-issue note below).
Known issue:
`torch.distributed.all_gather` will still do an implicit memcpy
at the end of each `ncclAllgather`.
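As a rough illustration of the split-launch idea, the sketch below all-gathers
a partitioned parameter directly into views of a pre-allocated flat buffer, so
no explicit Python-side copy is needed afterwards. The function name and
arguments are illustrative only, not the actual DeepSpeed API.

```python
# Minimal sketch, not the DeepSpeed implementation: gather each rank's shard
# straight into a slice of the flat output buffer.
import torch
import torch.distributed as dist

def allgather_param_split_launch(partition, world_size, group=None):
    """All-gather a 1-D partition into a flat buffer without a Python-side memcpy."""
    partition_sz = partition.numel()
    # Pre-allocate the flat destination once; every element is overwritten below.
    flat = torch.empty(partition_sz * world_size,
                       dtype=partition.dtype,
                       device=partition.device)
    # Hand all_gather views into the flat buffer so results land in place.
    output_views = [flat.narrow(0, i * partition_sz, partition_sz)
                    for i in range(world_size)]
    dist.all_gather(output_views, partition, group=group)
    return flat
```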
* WIP: wrapped ncclAllgather as a customized op in DS
A micro benchmark shows the improvement when all-gathering a
transformer layer with 9,834,560 elements in half precision is about
1.1 ms on an aws-p4d instance (see the timing sketch below).
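For reference, a micro benchmark along these lines can be reproduced with plain
torch.distributed and CUDA events. This is only a sketch that assumes the
process group is already initialized (e.g. via torchrun); it is not the harness
used for the numbers above.

```python
import torch
import torch.distributed as dist

def time_allgather(numel=9834560, iters=100):
    """Average time (ms) to all-gather a half-precision buffer of `numel` elements."""
    world_size = dist.get_world_size()
    shard = torch.randn(numel, dtype=torch.half, device="cuda")
    outputs = [torch.empty_like(shard) for _ in range(world_size)]
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # Warm up to exclude one-time NCCL communicator setup from the measurement.
    for _ in range(10):
        dist.all_gather(outputs, shard)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        dist.all_gather(outputs, shard)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```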
* WIP: integrated into partition_parameters
Performance improvement for a 5.1B-parameter BERT on aws-p4d:
fwd: 300ms -> 200ms
bwd: 680ms -> 610ms
* Fix format
* cleaned dead code, modified unit test
* removed the customized C++ extension;
reverted back to using the torch distributed API
* change torch.ones to torch.empty
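The rationale, as a small hedged example: the gathered buffer is fully
overwritten by the collective, so initializing it with ones is wasted work.
The variable names below are placeholders, not the actual code.

```python
# Allocation only, no fill kernel; safe because all_gather overwrites every element.
flat_buffer = torch.empty(partition_sz * world_size,
                          dtype=param.dtype,
                          device=param.device)
```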
* typo
* warn if the tensor passed to allgather is not a CUDA tensor
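A possible shape of that guard (hypothetical wording; the actual message in the
code may differ):

```python
import warnings

# NCCL-backed all_gather requires CUDA tensors, so flag the problem early
# instead of failing inside the collective.
if not tensor.is_cuda:
    warnings.warn("all_gather received a non-CUDA tensor; "
                  "the NCCL backend requires tensors on a CUDA device")
```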
* fix formatting
* fix: move ds_tensor to the cuda device
It is strange that ds_tensor hadn't already been moved to cuda.
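A minimal sketch of the workaround, assuming the partitioned shard is reachable
as param.ds_tensor:

```python
# Move the partitioned shard onto the current CUDA device before gathering,
# in case it is still (unexpectedly) on the CPU.
if not param.ds_tensor.is_cuda:
    param.ds_tensor = param.ds_tensor.to(
        torch.device("cuda", torch.cuda.current_device()))
```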
* remove the try clause on the param-fetching path
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>