Use `_all_gather_base` and fuse matmul for sharded linear.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78477
Use `_all_gather_base` instead of all_gather for col-wise sharding
since `_all_gather_base` returns a single fused tensor that can be used to
perform a single matmul instead of looping through and performing multiple
matmuls.
This improves performance for col-wise sharding.
Differential Revision: [D36754385](https://our.internmc.facebook.com/intern/diff/D36754385/)
Approved by: https://github.com/aazzolini, https://github.com/wanchaol