MoE gate fixes and enhancements (#5156)
Fixes the following issues:
- Fix capacity when using TP for non-MoE by aligning the capacity to TP
- Fix TopKGate.wg (gate weight) when using ZeRO with fp16 or bf16
- Fix top2 aux loss to be similar to top1 aux loss
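
As an illustration of the first fix, here is a minimal sketch of aligning expert capacity to the tensor-parallel (TP) size by rounding it up to the next multiple of TP. The function name `aligned_capacity` and its parameters are hypothetical, chosen for this example only; the actual change lives in the gate's capacity computation and may differ in detail.

```python
import math


def aligned_capacity(num_tokens, num_experts, capacity_factor, min_capacity, tp_size=1):
    """Hypothetical helper: per-expert capacity, rounded up to a multiple of tp_size."""
    # Base capacity: tokens per expert scaled by the capacity factor, rounded up.
    capacity = math.ceil(num_tokens / num_experts * capacity_factor)
    # Never go below the configured minimum capacity.
    capacity = max(capacity, min_capacity)
    # Assumption: with TP > 1 the capacity must be divisible by the TP size,
    # so round it up to the next multiple of tp_size.
    if tp_size > 1:
        capacity = math.ceil(capacity / tp_size) * tp_size
    return capacity


print(aligned_capacity(100, 8, 1.0, 4))             # 13
print(aligned_capacity(100, 8, 1.0, 4, tp_size=4))  # 16
```

Rounding up rather than down keeps the aligned capacity at least as large as the unaligned one, so alignment by itself never drops additional tokens.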
The following are a few configurable enhancements:
- Support top2 with token dropping disabled
- Support disabling top2 2nd expert sampling
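
To make the second enhancement concrete, the sketch below shows one common way top-2 routing picks the second expert: Gumbel noise is added to the logits before the second argmax (sampling), and disabling sampling falls back to the plain second-highest logit. This is a simplified single-token illustration using only the standard library; `top2_experts`, `sample_second`, and `rng` are hypothetical names, not the gate's actual API.

```python
import math
import random


def top2_experts(logits, sample_second=True, rng=None):
    """Hypothetical helper: return (first, second) expert indices for one token."""
    rng = rng or random.Random()
    # The first expert is always the plain argmax of the logits.
    first = max(range(len(logits)), key=lambda i: logits[i])
    scores = list(logits)
    if sample_second:
        # Assumption: sampling is done by adding Gumbel(0, 1) noise,
        # generated as -log(-log(u)) with u drawn uniformly from (0, 1).
        scores = [s - math.log(-math.log(rng.uniform(1e-9, 1 - 1e-9))) for s in scores]
    # Exclude the first expert, then take the argmax of what remains.
    scores[first] = float("-inf")
    second = max(range(len(scores)), key=lambda i: scores[i])
    return first, second


# With sampling disabled, the choice is deterministic.
print(top2_experts([0.1, 2.0, 1.5, -0.3], sample_second=False))  # (1, 2)
```

With sampling enabled the second expert varies from call to call, which is why a switch to disable it is useful when deterministic routing is wanted.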
---------
Signed-off-by: Moshe Island <misland@habana.ai>
Co-authored-by: Moshe Island <misland@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>