pytorch
0f7e60d6 - [shard] add ShardedTensor.cpu() (#74941)

Commit
2 years ago
[shard] add ShardedTensor.cpu() (#74941) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74941 Add ShardedTensor.cpu() API, we choose to add this first instead of `ShardedTensor.to()` bc the latter one is ambiguous and tricker (i.e. we need to figure out what's the meaning of ShardedTensor.to("cuda")). Note that when moving ShardedTensor to CPU from GPU, user need to explicitly pass in a new process_group that is compatible with the device, otherwise we will error out. ghstack-source-id: 153455401 Test Plan: test_sharded_tensor_cpu Reviewed By: pritamdamania87 Differential Revision: D35238133 fbshipit-source-id: 71921942f9295384f4e0eec2bfcdcff63a3d047c (cherry picked from commit a7d97f70f81648f0f44cd77b4b064b4967aaa2d5)
Author
Committer
Parents
Loading