Fix memory leak in ShardedTensor. (#71445)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71445
A reference to the ShardedTensor was always added to the global map
`_sharded_tensor_map`, and since the map held a strong reference, the
ShardedTensor could never be garbage collected.
A couple of fixes for this:
1) Add to the global map only when `init_rrefs=True`, since only that
codepath requires it.
2) Store a `weakref` in the global map so the map no longer keeps the
ShardedTensor alive forever.
ghstack-source-id: 147299580
Test Plan: waitforbuildbot
Reviewed By: fduwjj
Differential Revision: D33641013
fbshipit-source-id: c552fa3359186514445fd5715bec93f67dc2262d
(cherry picked from commit d25f1a645313dcbf8c37158d80c42c983262cec2)