DeepSpeed
5fbc3eee - Ensure capacity does not exceed number of tokens (#5353)

Ensure capacity does not exceed number of tokens (#5353)

When fine-tuning, we were running into issues where the capacity would trigger the following error after some amount of training. This was caused when the sizes of the inputs to top1gating were not aligned between ranks.

```
...
  File "/shared/users/jrasley/DeepSpeed/deepspeed/moe/sharded_moe.py", line 427, in forward
    gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
  File "/shared/users/jrasley/DeepSpeed/deepspeed/moe/sharded_moe.py", line 240, in top1gating
    top_idx = _top_idx(mask1_rand, capacity)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/shared/users/jrasley/DeepSpeed/deepspeed/moe/sharded_moe.py", line 172, in _top_idx
    @torch.jit.script
    def _top_idx(source, k):
        return torch.topk(source, k=k, dim=0)[1]
               ~~~~~~~~~~ <--- HERE
RuntimeError: selected index k out of range
```

Co-authored with: @rajhans
Reviewed/approved by: @samyam, @yaozhewei
Tagging @tohtana and @ykim362 to help review
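The failure above happens because `torch.topk` raises `selected index k out of range` when `k` exceeds the size of the dimension it selects over, i.e. when the expert capacity is larger than the number of tokens on a rank. The fix the title describes amounts to clamping capacity to the token count before the top-k call. A minimal pure-Python sketch of that clamp (using `heapq.nlargest` as a stand-in for `torch.topk(..., dim=0)[1]`; the helper names here are illustrative, not DeepSpeed's actual API):

```python
import heapq

def top_idx(scores, k):
    # Indices of the k largest scores; a 1-D stand-in for
    # torch.topk(source, k=k, dim=0)[1].
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)

def select_within_capacity(scores, capacity):
    # Clamp capacity to the number of tokens so top-k never asks for
    # more elements than exist -- without this, a capacity larger than
    # the token count triggers "selected index k out of range".
    capacity = min(capacity, len(scores))
    return top_idx(scores, capacity)
```

With four tokens and a capacity of ten, `select_within_capacity` quietly caps the selection at four indices instead of failing.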