fix: topkgating major bug (#7986)
`topk_masked_gates` was previously used across the tokens
dimension to determine which tokens had the highest importance for each
expert. However, it ranked tokens by raw logits rather than by softmax
probabilities. Softmax normalizes each token's row independently, so while a
token's own expert ranking is unaffected, the ordering of tokens *across*
the batch for a given expert can change.
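For intuition, here is a minimal sketch of the distinction (illustrative
tensors and names, not the actual DeepSpeed code):
```
import torch
import torch.nn.functional as F

logits = torch.randn(8, 4)        # [tokens, experts] router logits
gates = F.softmax(logits, dim=1)  # normalize each token's row independently

# Buggy: rank the tokens for each expert by raw logit magnitude.
_, top_tokens_buggy = torch.topk(logits, k=2, dim=0)

# Fixed: rank the tokens for each expert by routing probability.
# The orderings differ because each row of `gates` has its own normalizer.
_, top_tokens_fixed = torch.topk(gates, k=2, dim=0)
```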
This was causing print statements like
```
# debug instrumentation inside topkgating: `mask` is the top-k
# token-to-expert assignment, `capacity_mask` keeps tokens within capacity
if dist.get_rank() == 0:
    print(f"Mask mean: {mask.float().mean()}")
    print(f"Capacity mask mean: {capacity_mask.mean()}")
mask = torch.logical_and(mask, capacity_mask)
if dist.get_rank() == 0:
    print(f"Mask (after AND) mean: {mask.float().mean()}")
```
to often yield values like
```
Mask mean: 0.0625
Capacity mask mean: 0.0625
Mask (after AND) mean: 0.005908316932618618
```
and, in turn, the average number of routed experts per token to fall as low
as `0.001`: the capacity mask, built from logit rankings, rarely intersected
the actual top-k assignments, so most assigned tokens were dropped.
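For reference, that number is just the mean row-sum of the final mask
(a sketch, assuming `mask` is the `[tokens, experts]` assignment after the
AND):
```
# average number of experts each token is routed to; for top-k routing
# this should be close to k, minus tokens dropped for capacity
experts_per_token = mask.float().sum(dim=-1).mean().item()
```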
---------
Signed-off-by: Daniel Shen <dandanshen2002@gmail.com>