Add print statements to debug sharding error (#102713)
sharding on rocm is broken, i cant replicate on dummy PRs even though it seems to happen pretty often on main, so adding this to increase my sample size. Hopefully this is enough print statements...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102713
Approved by: https://github.com/huydhn