[PT-D] Fix Sharding spec inference to avoid invalid chunk sharding to be inferred as chunkshardingspec (#75296)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75296
Our previous logic does not fully all corner cases when it comes to imbalance sharding. For example, one is 16 another one is 9. (This cannot be chunk sharding) because the total length is 25 and if it's chunk sharding, it should be 13 and 12. So we need to get the total length first and calculated the expected chunk length to ensure it's indeed a chunk sharding case.
Also added more test cases in unit test.
ghstack-source-id: 153190671
Test Plan: CI
Reviewed By: fegin
Differential Revision: D35417257
fbshipit-source-id: a7df2183f9747c765498eb460678709b76cdf7b4
(cherry picked from commit 730450c7398e49e2615b174505bd9854cfc668b5)