[MoE] Fix misuse of num_experts as expert parallel group size (ep_size) (#7551)
Fixes #7535
## Description
This PR fixes a bug in `inference/engine.py` where `num_experts`
(`moe_experts`) was incorrectly passed as the expert parallel group size
(`ep_size`) when creating expert parallel groups.
Currently:
```
if moe and dist.get_world_size() > 1:
    self._create_ep_parallel_group(config.moe.moe_experts)
```
This causes **invalid** behavior whenever `num_experts > world_size`,
because `_create_ep_parallel_group` expects a group size, not the total
number of experts, as pointed out by @Arnoochka.
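As a concrete illustration (hypothetical sizes, not taken from the issue), passing the expert count as the group size silently produces zero expert parallel groups:
```
# Hypothetical sizes to show the failure mode of the old call:
world_size = 4    # GPUs available
num_experts = 8   # experts in the MoE layer (what the old code passed as ep_size)

# _create_ep_parallel_group divides the world into groups of the given size,
# so passing num_experts here yields zero groups when num_experts > world_size.
num_ep_groups = world_size // num_experts
print(num_ep_groups)  # 0 -> no expert parallel group is ever created
```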
## Root Cause
- `num_experts`: the number of experts inside the MoE layer.
- `ep_size`: how many GPUs to group together for expert parallelism.

These two values were conflated in the code.
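A rough sketch of the intended relationship (hypothetical sizes, not taken from the PR):
```
# Hypothetical sizes showing how ep_size relates to num_experts:
world_size = 4    # total GPUs
ep_size = 2       # GPUs per expert parallel group
num_experts = 8   # experts in the MoE layer

num_ep_groups = world_size // ep_size      # 2 independent expert parallel groups
experts_per_gpu = num_experts // ep_size   # 4 experts hosted on each GPU in a group
print(num_ep_groups, experts_per_gpu)      # 2 4
```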
## Fix
Replaced the incorrect call with the proper ep_size argument:
```
if moe and dist.get_world_size() > 1:
    self._create_ep_parallel_group(config.moe.ep_size)
```
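With the fix, an inference config can set `ep_size` independently of the expert count. A minimal usage sketch (the model object and sizes are placeholders; field names follow the `config.moe.*` attributes referenced above):
```
import deepspeed

# `model` is an already-constructed MoE model (placeholder).
engine = deepspeed.init_inference(
    model,
    config={
        "moe": {
            "enabled": True,
            "ep_size": 2,        # expert parallel group size (GPUs per group)
            "moe_experts": [8],  # experts per MoE layer
        },
    },
)
```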
Additionally, a safety check was added to `_create_ep_parallel_group` to
catch invalid configurations:
```
num_ep_groups = dist.get_world_size() // moe_ep_size
if num_ep_groups == 0:
    raise ValueError(
        f"Invalid ep_size={moe_ep_size} for world_size={dist.get_world_size()}"
    )
```
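For reference, here is a standalone sketch (a hypothetical helper, not DeepSpeed's actual code) of the rank layout this check protects: the world is partitioned into `world_size // ep_size` groups of `ep_size` consecutive ranks each.
```
# Hypothetical helper mirroring the grouping arithmetic above.
def ep_group_ranks(world_size: int, ep_size: int) -> list[list[int]]:
    num_ep_groups = world_size // ep_size
    if num_ep_groups == 0:
        raise ValueError(f"Invalid ep_size={ep_size} for world_size={world_size}")
    return [list(range(g * ep_size, (g + 1) * ep_size)) for g in range(num_ep_groups)]

print(ep_group_ranks(world_size=8, ep_size=2))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
```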
## Backward compatibility
- If a user was already running with `ep_size >= num_experts`, the old
code happened to behave correctly, and that case continues to work after
this change.
- Only the previously broken case (`num_experts > world_size`) now works
correctly.
Signed-off-by: Flakes342 <ayushtanwar1729@gmail.com>