DeepSpeed
9c86cd98 - z2: don't pass `dtype` to `report_ipg_memory_usage` (#7636)

Commit
65 days ago
z2: don't pass `dtype` to `report_ipg_memory_usage` (#7636) This PR is fixing this: ``` [rank0]: File "/code/users/stas/github/DeepSpeed/deepspeed/runtime/zero/stage_1_and_2.py", line 985, in grad_handling_hook [rank0]: self.process_gradients(param, i) [rank0]: File "/code/users/stas/github/DeepSpeed/deepspeed/runtime/zero/stage_1_and_2.py", line 1524, in process_gradients [rank0]: self.reduce_ready_partitions_and_remove_grads(param, i) [rank0]: File "/code/users/stas/github/DeepSpeed/deepspeed/runtime/zero/stage_1_and_2.py", line 1528, in reduce_ready_partitions_and_remove_grads [rank0]: self.reduce_independent_p_g_buckets_and_remove_grads(param, i) [rank0]: File "/code/users/stas/github/DeepSpeed/deepspeed/runtime/zero/stage_1_and_2.py", line 1006, in reduce_independent_p_g_buckets_and_remove_grads [rank0]: self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel(), param.dtype) [rank0]: File "/code/users/stas/github/DeepSpeed/deepspeed/runtime/base_optimizer.py", line 70, in report_ipg_memory_usage [rank0]: bucket = self.ipg_buckets[dt] [rank0]: ~~~~~~~~~~~~~~~~^^^^ [rank0]: KeyError: torch.bfloat16 ``` the problem doesn't exist if: `seq_parallel_communication_data_type: bf16` is used, but fails with `fp32` (or no setting). In this PR I'm syncing with the z3 implementation which doesn't pass the `dtype` arg and lets the traversal of existing dtypes do the thing. https://github.com/deepspeedai/DeepSpeed/blob/407708cdb6e48dbff971b0f03ec4613d0f084a4b/deepspeed/runtime/base_optimizer.py#L66-L75 Fixes: https://github.com/deepspeedai/DeepSpeed/issues/7607
Author
Parents
Loading