Various ZeRO Stage3 Optimizations + Improvements (including bfloat16 support) (#1453)
* Changes for bfloat16 support in ZeRO Stage 2
* ZeRO stage 3 optimizations, with some bug fixes
optimizations for stage 3:
- prefetching improvements
- batching allgather calls to amortize fixed overhead and improve
bandwidth utilization (see the sketch after this list)
- batching reduce_scatter calls to amortize fixed overhead and
improve bandwidth utilization
- using the *_base variants of allgather and reduce_scatter to reduce memory
allocations and data movement
- finer-grained synchronization for communication, so we block on
less work
- precomputing the fetch schedule: use a fetch queue rather than
deciding what to (pre)fetch at each iteration
- limiting queued coalesced communication ops to reduce memory pressure
on the PyTorch CUDA caching allocator (not an elegant solution)
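
A minimal sketch of the batched allgather idea (the function and its arguments are hypothetical; the real logic lives in DeepSpeed's parameter coordinator): flatten several same-sized local shards into one buffer and issue a single `_all_gather_base` call instead of one allgather per parameter.

```python
import torch
import torch.distributed as dist

def allgather_coalesced(shards, group=None):
    """Gather a list of flat local shards (one per param) with a single
    collective call. Assumes every rank holds shards of identical sizes."""
    world_size = dist.get_world_size(group)
    # Pack all local shards into one flat send buffer.
    flat_in = torch.cat([s.view(-1) for s in shards])
    flat_out = torch.empty(flat_in.numel() * world_size,
                           dtype=flat_in.dtype, device=flat_in.device)
    # One fused collective amortizes the per-call launch overhead.
    # (_all_gather_base was renamed all_gather_into_tensor in newer PyTorch.)
    dist._all_gather_base(flat_out, flat_in, group=group)
    # Note: the output is laid out rank-major, so each full parameter still
    # has to be reassembled from its per-rank segments afterwards.
    return flat_out
```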
optimizations for stage 3 offload:
- made some host-to-device tensor copies asynchronous to improve performance
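
The async host-to-device copy mentioned above boils down to copying from pinned memory on a side stream, roughly (names are illustrative):

```python
import torch

def async_h2d_copy(host_tensor, device_tensor, stream):
    """Host-to-device copy that does not block the default stream.
    The copy is only truly asynchronous if host_tensor is pinned."""
    assert host_tensor.is_pinned()
    with torch.cuda.stream(stream):
        device_tensor.copy_(host_tensor, non_blocking=True)

# usage sketch:
#   copy_stream = torch.cuda.Stream()
#   async_h2d_copy(pinned_buf, gpu_buf, copy_stream)
#   torch.cuda.current_stream().wait_stream(copy_stream)  # before using gpu_buf
```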
bug fixes and quality-of-life improvements:
- fix the init context when parent modules modify child weights
- speed up model initialization by moving the model to GPU before weight
initialization
- fix unit test imports so that unit tests can be run from any
directory
- change performance logging to include memory consumption (sketched below)
- log the model size when done partitioning the model
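
The memory numbers added to the performance logs are along these lines (helper name is hypothetical):

```python
import torch

def memory_usage_str():
    """Summarize current/peak allocated and reserved CUDA memory for logging."""
    gib = 1024 ** 3
    return (f"MA {torch.cuda.memory_allocated() / gib:.2f} GiB | "
            f"Max_MA {torch.cuda.max_memory_allocated() / gib:.2f} GiB | "
            f"CA {torch.cuda.memory_reserved() / gib:.2f} GiB")
```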
new features:
- bfloat16 support for ZeRO stage 3
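
Enabling the new bf16 path together with ZeRO stage 3 looks roughly like this in the DeepSpeed config (only the relevant keys shown; the "bfloat16" -> "bf16" key rename is noted further down):

```python
# Minimal config sketch; pass this dict (or the equivalent JSON file)
# as the DeepSpeed config when initializing the engine.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 3},
}
```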
* fix import in unit test
* ran yapf
* improve the cache flush warning log
* backwards compatibility with older versions of pytorch
* handle the edge case where the reduced tensor is smaller than the world size
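
A sketch of one way to handle that edge case (the PR's actual fix may differ): zero-pad the flat tensor so its length is divisible by the world size.

```python
import torch

def pad_to_world_size(flat_tensor, world_size):
    """Zero-pad a flat tensor so its element count is divisible by
    world_size (this also covers tensors smaller than world_size)."""
    remainder = flat_tensor.numel() % world_size
    if remainder == 0:
        return flat_tensor
    padding = torch.zeros(world_size - remainder,
                          dtype=flat_tensor.dtype,
                          device=flat_tensor.device)
    return torch.cat([flat_tensor, padding])
```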
* moved event synchronization to allgather handle wait() call
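
A sketch of the deferred synchronization pattern, assuming the allgather is launched on a dedicated communication stream (the class body and arguments are illustrative, not the actual implementation):

```python
import torch
import torch.distributed as dist

class AllGatherHandle:
    """Launches an allgather on a side stream and defers synchronization
    until wait() is called by whoever consumes the gathered params."""

    def __init__(self, flat_out, flat_in, comm_stream, group=None):
        # Make sure the input is ready before the communication stream reads it.
        comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(comm_stream):
            dist._all_gather_base(flat_out, flat_in, group=group)
            self._done = torch.cuda.Event()
            self._done.record(comm_stream)

    def wait(self):
        # Only the consuming stream blocks, and only when wait() is called,
        # not at launch time.
        torch.cuda.current_stream().wait_event(self._done)
```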
* removed unnecessary barrier call
* formatting fix after resolving merge conflict
* skip NVMe prefetch when the trace is not complete
* avoid memory allocation in allgather coalesced where possible
* fix indentation after merge
* fixes to account for parameter offload
* account for torch.cuda.memory_stats not being available in older PyTorch
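
The guard is essentially a feature check, e.g.:

```python
import torch

def safe_memory_stats():
    """Return torch.cuda.memory_stats() if this PyTorch build has it,
    otherwise an empty dict (older releases lack the API)."""
    if hasattr(torch.cuda, "memory_stats"):
        return torch.cuda.memory_stats()
    return {}
```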
* moved partition_all_params to optimizer step
* allgather params before .item() gets called
* fix param status checks (needed after moving the partition_all_parameters
call to the optimizer step)
* fix grad accumulation with optimizer offload
* grad norm computation fix for optimizer offload
* change the post-divide in reduce-scatter to a pre-divide
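
Pre-divide vs. post-divide, in sketch form: scale the gradients by 1/world_size before the reduce-scatter rather than dividing the result afterwards, which among other things lowers the risk of overflow when summing in reduced precision.

```python
import torch
import torch.distributed as dist

def reduce_scatter_predivide(flat_grad, group=None):
    """Average a flat gradient across ranks, dividing before the collective.
    Assumes flat_grad.numel() is divisible by the world size."""
    world_size = dist.get_world_size(group)
    flat_grad.div_(world_size)  # pre-divide instead of post-divide
    out = torch.empty(flat_grad.numel() // world_size,
                      dtype=flat_grad.dtype, device=flat_grad.device)
    dist._reduce_scatter_base(out, flat_grad, group=group)
    return out
```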
* fix gradient race condition w/ optimizer offload
* improve inf/nan gradient tracking
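
Inf/NaN tracking boils down to a local finiteness check that is then combined across ranks so every rank agrees on whether to skip the step; a sketch with hypothetical helper names:

```python
import torch
import torch.distributed as dist

def local_overflow(grads):
    """True if any local gradient contains an inf or a NaN."""
    return any(not torch.isfinite(g).all() for g in grads)

def global_overflow(grads, group=None):
    """All-reduce the local flag so every rank makes the same skip decision."""
    flag = torch.tensor([1.0 if local_overflow(grads) else 0.0], device="cuda")
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    return bool(flag.item())
```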
* don't prefetch when not in training mode
* format fix after merging
* fix prefetching issue when using NVMe offload
* improved defragmentation for fp16 parameters
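
The defragmentation idea, sketched with hypothetical names: copy the individual fp16 partitions into one contiguous flat buffer and re-point each partition at a narrow view of that buffer, so the caching allocator sees one large block instead of many small fragments.

```python
import torch

def defragment(partitions):
    """Move a list of flat fp16 partitions into one contiguous buffer and
    return views into it (one view per original partition)."""
    total = sum(p.numel() for p in partitions)
    flat = torch.empty(total, dtype=partitions[0].dtype,
                       device=partitions[0].device)
    views, offset = [], 0
    for p in partitions:
        view = flat.narrow(0, offset, p.numel())
        view.copy_(p)
        views.append(view)
        offset += p.numel()
    return flat, views
```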
* relative imports for bf16 tests
* changes for backward compatibility with PyTorch 1.2
* remove buffered_reduce_fallback
* removed unused parameter offset bookkeeping
* fixed tracking for multiple param groups
* fix bfloat16 config broken by a merge conflict
* use the base allgather when there is only 1 param
* cleanup/fixes for fp16 partition defragmentation
* switch to CRLF
* convert to the same newline style as master
* align newline style with master
* Fix merge issues
* switch to CRLF
* fix to LF line endings
* minor merge fixes
* remove extra bfloat16_enabled definition
* assert that params are in flight for AllGatherHandle
* remove get_cuda_mem_allocated_str
* Format fixes
* fix bfloat16 zero stage check (broken after merge commit)
* add self.communication_data_type, remove self.allreduce_always_fp32; delete dead code
* Add self.reduce_scatter
* Format fix
* Fix merge issues
* iterate over params_to_fetch rather than making another iterator
* add some TODOs
* remove unnecessary division by micro_step_id
* rename config keys "bfloat16" -> "bf16"
* rename stage3_gather_fp16_weights_on_model_save -> stage3_gather_16bit_weights_on_model_save
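
A sketch of the kind of backwards-compatible key handling the rename implies (the helper below is hypothetical; only the key names come from this PR):

```python
def get_gather_16bit_weights(zero_config: dict) -> bool:
    """Prefer the new key, fall back to the deprecated one."""
    new_key = "stage3_gather_16bit_weights_on_model_save"
    old_key = "stage3_gather_fp16_weights_on_model_save"
    if new_key in zero_config:
        return bool(zero_config[new_key])
    return bool(zero_config.get(old_key, False))
```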
* add unit test to check backwards compatibility for gather_16bit_weights
* add test to confirm bf16 key backwards compatibility
* Format fixes
Co-authored-by: Rana Ali Amjad <raamjad@amazon.com>
Co-authored-by: Justin Chiu <justchiu@amazon.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>