Ensure capacity does not exceed number of tokens (#5353)
When fine-tuning, we ran into a case where the computed capacity exceeded the number of tokens, which after some amount of training triggered the error below. The root cause was that the sizes of the inputs to top1gating were not aligned between ranks; a minimal sketch of the failure and the guard follows the traceback.
```
...
File "/shared/users/jrasley/DeepSpeed/deepspeed/moe/sharded_moe.py", line 427, in forward
gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
File "/shared/users/jrasley/DeepSpeed/deepspeed/moe/sharded_moe.py", line 240, in top1gating
top_idx = _top_idx(mask1_rand, capacity)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
File "/shared/users/jrasley/DeepSpeed/deepspeed/moe/sharded_moe.py", line 172, in _top_idx
@torch.jit.script
def _top_idx(source, k):
return torch.topk(source, k=k, dim=0)[1]
~~~~~~~~~~ <--- HERE
RuntimeError: selected index k out of range
```
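For illustration, here is a minimal sketch of the failure mode and the clamp the title describes. This is not the actual patch; the tensor shapes and the `min` guard placement are assumptions for demonstration only:
```python
# Minimal sketch: torch.topk raises "selected index k out of range" when
# k exceeds the input's size along the reduction dimension, so the gating
# capacity must never exceed the number of tokens. Shapes are hypothetical.
import torch

def _top_idx(source, k):
    # Same contract as deepspeed/moe/sharded_moe.py: indices of the
    # top-k tokens for each expert column.
    return torch.topk(source, k=k, dim=0)[1]

num_tokens, num_experts = 8, 4
mask1_rand = torch.rand(num_tokens, num_experts)

capacity = 16  # derived from capacity_factor; can exceed num_tokens on a short/misaligned batch
# _top_idx(mask1_rand, capacity)  # would raise: selected index k out of range

capacity = min(capacity, num_tokens)  # guard: never request more rows than exist
top_idx = _top_idx(mask1_rand, capacity)
print(top_idx.shape)  # torch.Size([8, 4])
```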
Co-authored with: @rajhans
Reviewed/approved by: @samyam, @yaozhewei
Tagging @tohtana and @ykim362 to help review