Fix `shard_module` to correctly handle sub process groups.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79264
The `shard_module` API did not work correctly with a sub process group (sub-pg)
because `dist.scatter` takes the global rank, not the rank within the group, as its `src` argument.
This is fixed by passing the appropriate global rank to `dist.scatter`.
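
The rank translation involved can be illustrated with a minimal, pure-Python sketch. It is not the code from this PR: `to_global_rank` and `sub_group_ranks` are hypothetical names, and the sketch assumes the sub process group is described by a sorted list of the global ranks it contains, so that a rank within the group is an index into that list.

```python
from typing import Sequence


def to_global_rank(group_ranks: Sequence[int], group_rank: int) -> int:
    """Hypothetical helper: map a rank within a sub-group to its global rank.

    Assumes `group_ranks` lists the global ranks that make up the sub-group,
    ordered so that position i holds the member with group rank i.
    """
    if not 0 <= group_rank < len(group_ranks):
        raise ValueError(
            f"group rank {group_rank} is out of range for a group of size {len(group_ranks)}"
        )
    return group_ranks[group_rank]


# Assumed setup: a 4-process job with a sub-group made of global ranks 2 and 3.
sub_group_ranks = [2, 3]

# The first member of the sub-group has group rank 0 but global rank 2.
# Passing 0 as `src` would name a process outside the sub-group,
# so the global rank is what a call such as `dist.scatter` needs.
src_global = to_global_rank(sub_group_ranks, 0)
print(src_global)
```

Recent PyTorch releases also provide `torch.distributed.get_global_rank(group, group_rank)`, which performs this kind of lookup against a real process group.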
Differential Revision: [D37062766](https://our.internmc.facebook.com/intern/diff/D37062766/)
Approved by: https://github.com/fduwjj, https://github.com/wanchaol