[PyTorch][Distributed] Use autograd-enabled collectives for the sharded linear op to enable backward grad calculation (#68096)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68096
We replace all c10d APIs with their autograd-enabled collective counterparts in the sharded linear op, so that backward propagation (grad calculation for sharded linear) works.
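As a minimal illustration of the difference this change relies on: the plain c10d collectives in `torch.distributed` do not participate in autograd, while the functional variants in `torch.distributed.nn.functional` do, letting gradients flow back through the communication op. The sketch below (not the actual sharded linear code from this PR) demonstrates this with a single-rank gloo process group so it runs standalone.

```python
import tempfile
import torch
import torch.distributed as dist
# Autograd-enabled collectives live in torch.distributed.nn.functional,
# unlike the plain (non-differentiable) c10d ops in torch.distributed.
from torch.distributed.nn.functional import all_gather

# Single-rank gloo process group so the example is self-contained.
init_file = tempfile.NamedTemporaryFile(delete=False).name
dist.init_process_group(
    backend="gloo",
    init_method=f"file://{init_file}",
    rank=0,
    world_size=1,
)

x = torch.randn(4, requires_grad=True)

# The c10d all_gather writes into pre-allocated output tensors and is
# invisible to autograd; the functional variant returns new tensors that
# are part of the autograd graph.
gathered = all_gather(x)            # one tensor per rank
loss = torch.stack(list(gathered)).sum()
loss.backward()                     # gradients flow back through the collective

assert x.grad is not None           # would be None with the plain c10d op
dist.destroy_process_group()
```

With a single rank, `loss` reduces to `x.sum()`, so `x.grad` is a tensor of ones; in the multi-rank case the backward pass of the collective handles the corresponding reduce-scatter of gradients.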
ghstack-source-id: 144882914
Test Plan: Unit test + CI
Reviewed By: pritamdamania87
Differential Revision: D32177341
fbshipit-source-id: 1919e8ca877bdc79f4cdb0dc2a82ddaf6881b9f1