DeepSpeed
91f58c06 - [zero3] params_to_reduce isn't always there (#1214)

[zero3] params_to_reduce isn't always there (#1214) * [zero3] params_to_reduce isn't always there Trying to port HF's Electra model's to Deepspeed I'm getting this on the very first backward step (with some extra debug): ``` Incrementing with parameter id 42 ------ Before allocating allgather param name=generator_lm_head.weight id=41 shape=torch.Size([1]) status=ZeroParamStatus.NOT_AVAILABLE partition size=327680 ------allgather param with name=generator_lm_head.weight id=41 shape=torch.Size([1]) status=ZeroParamStatus.NOT_AVAILABLE partition size=327680 ------ Before allocating allgather param name=generator_lm_head.bias id=42 shape=torch.Size([1]) status=ZeroParamStatus.NOT_AVAILABLE partition size=5120 ------allgather param with name=generator_lm_head.bias id=42 shape=torch.Size([1]) status=ZeroParamStatus.NOT_AVAILABLE partition size=5120 Backward name=generator_lm_head.weight id=41 shape=torch.Size([5120, 64]) Inside reduce ipg buckets. name=generator_lm_head.weight id=41 shape=torch.Size([5120, 64]), ipg elements 0, reduce bucket size 4096 Params in ipg bucket [] Reducing [] GOT 1 torch.Size([4096]) Traceback (most recent call last): File "examples/pytorch/language-modeling/run_mlm.py", line 533, in <module> main() File "examples/pytorch/language-modeling/run_mlm.py", line 484, in main train_result = trainer.train(resume_from_checkpoint=checkpoint) File "/mnt/nvme1/code/huggingface/transformers-ds-zero_to_fp32-tests/src/transformers/trainer.py", line 1269, in train tr_loss += self.training_step(model, inputs) File "/mnt/nvme1/code/huggingface/transformers-ds-zero_to_fp32-tests/src/transformers/trainer.py", line 1778, in training_step loss = self.deepspeed.backward(loss) File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zero-init-child-only-post_init/deepspeed/runtime/engine.py", line 1188, in backward self.optimizer.backward(loss) File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zero-init-child-only-post_init/deepspeed/runtime/zero/stage3.py", line 2964, in backward self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zero-init-child-only-post_init/deepspeed/runtime/fp16/loss_scaler.py", line 53, in backward scaled_loss.backward(retain_graph=retain_graph) File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/site-packages/torch/_tensor.py", line 255, in backward torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/site-packages/torch/autograd/__init__.py", line 147, in backward Variable._execution_engine.run_backward( File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zero-init-child-only-post_init/deepspeed/runtime/zero/stage3.py", line 1867, in reduce_partition_and_remove_grads self.reduce_ready_partitions_and_remove_grads(param, i) File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zero-init-child-only-post_init/deepspeed/runtime/zero/stage3.py", line 2212, in reduce_ready_partitions_and_remove_grads self.reduce_independent_p_g_buckets_and_remove_grads(param, i) File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zero-init-child-only-post_init/deepspeed/runtime/zero/stage3.py", line 1897, in reduce_independent_p_g_buckets_and_remove_grads self.reduce_ipg_grads() File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zero-init-child-only-post_init/deepspeed/runtime/zero/stage3.py", line 2193, in reduce_ipg_grads self.average_tensor(reduction_list, params_to_reduce) File 
"/mnt/nvme1/code/github/00optimize/DeepSpeed-zero-init-child-only-post_init/deepspeed/runtime/zero/stage3.py", line 1972, in average_tensor params_to_reduce[0].reduce_gradients_at_owner( ``` Is it always that `params_to_reduce` is populated? If I add this check the problem goes away it seems. * real fix