DeepSpeed
853c9389 - fix: topkgating major bug (#7986)

fix: topkgating major bug (#7986)

`topk_masked_gates` was previously applied across the token dimension to determine which tokens had the highest importance for each expert. However, it was ranking tokens by logits rather than probabilities. This caused print statements like

```python
if dist.get_rank() == 0:
    print(f"Mask mean: {mask.float().mean()}")
    print(f"Capacity mask mean: {capacity_mask.mean()}")
mask = torch.logical_and(mask, capacity_mask)
if dist.get_rank() == 0:
    print(f"Mask (after AND) mean: {mask.float().mean()}")
```

to often yield values like

```
Mask mean: 0.0625
Capacity mask mean: 0.0625
Mask (after AND) mean: 0.005908316932618618
```

and, in turn, the average number of routed experts per token to be as low as `0.001`.

---------

Signed-off-by: Daniel Shen <dandanshen2002@gmail.com>
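The failure mode above — ranking tokens per expert by raw logits, which are only comparable within a token (across experts), not across tokens — can be illustrated with a minimal capacity-limited top-k router. This is a hypothetical NumPy sketch, not DeepSpeed's implementation; the function name, signature, and capacity formula are invented for the example.

```python
import numpy as np

def topk_gating_with_capacity(logits, k=1, capacity=None):
    """Hypothetical sketch of top-k MoE gating with a per-expert capacity cut.

    Tokens are ranked per expert by softmax *probabilities*, which are
    normalized per token and therefore comparable across tokens, unlike
    raw logits.
    """
    tokens, experts = logits.shape
    if capacity is None:
        capacity = (tokens * k + experts - 1) // experts  # ceil(tokens*k/experts)

    # Per-token softmax over experts (numerically stabilized).
    z = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

    # Top-k experts per token -> routing mask, shape (tokens, experts).
    topk = np.argsort(-probs, axis=-1)[:, :k]
    mask = np.zeros_like(probs, dtype=bool)
    np.put_along_axis(mask, topk, True, axis=-1)

    # Per expert, rank routed tokens by probability; unrouted slots get -1
    # so they sort last and never consume capacity ahead of routed tokens.
    masked_probs = np.where(mask, probs, -1.0)
    order = np.argsort(-masked_probs, axis=0)  # token indices, best first, per expert
    rank = np.empty_like(order)
    np.put_along_axis(rank, order, np.arange(tokens)[:, None], axis=0)
    capacity_mask = rank < capacity

    # Tokens that are both routed to the expert and within its capacity.
    return mask, np.logical_and(mask, capacity_mask)
```

With a sufficiently large `capacity` the combined mask equals the routing mask (nothing is dropped); with a tight `capacity`, each expert keeps only its highest-probability tokens, and the mean of the combined mask degrades gracefully rather than collapsing as in the buggy logit-based ranking described above.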