Fix uneven head sequence parallelism bug (#6774) (#6797)
Here `gather_idx < 2` represents `is_first_all2all`. During the first
all2all, `uneven_head_all2all` will be called if either `num_heads %
seq_world_size != 0` or `get_num_kv_heads() is not None`.
During the second all2all, it will return `uneven_head_all2all` if and
only if `get_num_kv_heads() is not None`, and the number of KV heads is
always set during the first uneven all2all. This means there will no
longer be an issue where `uneven_head_all2all` is returned for the
second all2all only because `num_heads % seq_world_size != 0`.
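A minimal sketch of the dispatch rule described above, not the actual
DeepSpeed code. `takes_uneven_path` and the module-level `_num_kv_heads`
are hypothetical stand-ins, and the sketch assumes the uneven path is
re-entered whenever the stored head count has already been set:

```python
# Hypothetical stand-in for the module-level state; illustration only.
_num_kv_heads = None


def get_num_kv_heads():
    return _num_kv_heads


def set_num_kv_heads(n):
    global _num_kv_heads
    _num_kv_heads = n


def takes_uneven_path(num_heads, seq_world_size, gather_idx):
    """Return True if this all2all would dispatch to the uneven-head path."""
    is_first_all2all = gather_idx < 2
    uneven = num_heads % seq_world_size != 0
    # The divisibility check only counts on the first all2all. Later calls
    # take the uneven path only if the first call recorded the head count.
    if get_num_kv_heads() is not None or (uneven and is_first_all2all):
        if get_num_kv_heads() is None:
            set_num_kv_heads(num_heads)
        return True
    return False
```

With 8 heads and `seq_world_size = 4`, the first call
(`gather_idx=1`) returns `False`. A second call that sees 2 local heads
(assuming `gather_idx=2`) also returns `False`, even though `2 % 4 != 0`.
With 6 heads, the first call returns `True` and records the head count,
so the second call returns `True` as well.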
Fixes: #6774
---------
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>