fix uneven issue & add balance autotp (#4697)
This PR aims to balance the shard sizes across workers as evenly as
possible.
1. We refactor the tp_shard logic so that AutoTP works when
`split_shape % num_kv_heads != 0`.
2. When `num_kv_heads` is defined, the attention module must shard along
KV-head boundaries, but the mlp and lm_head modules can use near-even
division to get better-balanced shards, which improves performance
(see the sketch below).
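A minimal sketch of the two sharding strategies described above, not the
actual DeepSpeed implementation; the helper names `shard_sizes_near_even`
and `shard_sizes_by_kv_heads` are hypothetical:

```python
def shard_sizes_near_even(dim_size: int, num_workers: int) -> list[int]:
    """Near-even division for mlp/lm_head: shard sizes differ by at most 1."""
    base, remainder = divmod(dim_size, num_workers)
    # The first `remainder` workers each take one extra row/column.
    return [base + (1 if rank < remainder else 0) for rank in range(num_workers)]

def shard_sizes_by_kv_heads(dim_size: int, num_workers: int,
                            num_kv_heads: int) -> list[int]:
    """Attention sharding: shard boundaries must align with KV-head
    boundaries, so distribute whole heads near-evenly and scale by the
    per-head width."""
    assert dim_size % num_kv_heads == 0, "dim must be a multiple of num_kv_heads"
    head_dim = dim_size // num_kv_heads
    heads_per_worker = shard_sizes_near_even(num_kv_heads, num_workers)
    return [h * head_dim for h in heads_per_worker]

# Example: a 4096-wide mlp split across 3 workers no longer requires
# 4096 % 3 == 0; the shards are nearly equal.
print(shard_sizes_near_even(4096, 3))       # [1366, 1365, 1365]
# Attention with 8 KV heads (width 512 each) across 3 workers gets
# [3, 3, 2] heads per worker.
print(shard_sizes_by_kv_heads(4096, 3, 8))  # [1536, 1536, 1024]
```

With this split, each rank's shard size differs from the others by at most
one unit (or one KV head for attention), instead of the last rank absorbing
the whole remainder.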
---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>