pytorch
9b5d9b40 - Remove `_broadcast_object()` from `ZeroRedundancyOptimizer` (#61539)

Summary: Revised version of https://github.com/pytorch/pytorch/issues/60573.

**Overview:**

This makes two changes:
- It introduces a `map_location` argument to `broadcast_object_list()`. The argument specifies the device onto which tensors contained in objects received from the broadcast are loaded. This change requires modifying the implementations of `_object_to_tensor()` and `_tensor_to_object()` to use `torch.save()` and `torch.load()`, respectively.
- It removes all calls to `_broadcast_object()` in `ZeroRedundancyOptimizer` and the corresponding test file in favor of `broadcast_object_list()`.

The default value of `map_location` is `None`, in which case `_object_to_tensor()`, and hence `broadcast_object_list()`, preserves its original behavior: contained tensors are loaded onto their original devices.

In `consolidate_state_dict()`, I specify `map_location=torch.device("cpu")` instead of `self._default_device`. This slightly changes the behavior from before, when `_broadcast_object()` was used. The reason is that it saves one GPU-to-CPU data transfer, since the action immediately after receiving the broadcast `local_state_dict` is to copy it to CPU.
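The `torch.save()`/`torch.load()` round trip described above can be sketched roughly as follows. This is an illustrative standalone version, not PyTorch's actual `_object_to_tensor()`/`_tensor_to_object()` internals; the function names and the uint8-buffer layout here are assumptions for the sketch.

```python
import io
import torch

def object_to_tensor(obj):
    # Serialize an arbitrary picklable object (possibly containing tensors)
    # into a flat uint8 tensor via torch.save(), plus its byte length.
    buf = io.BytesIO()
    torch.save(obj, buf)
    data = buf.getvalue()
    byte_tensor = torch.tensor(bytearray(data), dtype=torch.uint8)
    size = torch.tensor([byte_tensor.numel()], dtype=torch.long)
    return byte_tensor, size

def tensor_to_object(byte_tensor, size, map_location=None):
    # Deserialize with torch.load(); map_location controls the device that
    # contained tensors land on (None preserves their original device).
    buf = io.BytesIO(byte_tensor[: size.item()].numpy().tobytes())
    return torch.load(buf, map_location=map_location)

# Round-trip a dict containing a tensor, forcing contained tensors onto CPU.
state = {"step": 3, "weight": torch.ones(2)}
t, n = object_to_tensor(state)
restored = tensor_to_object(t, n, map_location=torch.device("cpu"))
```

With `map_location=None`, the restored tensors would instead come back on whatever device they were saved from, which is the pre-existing `broadcast_object_list()` behavior.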
Explicitly, if `map_location=self._default_device`, then the data transfer path (assuming the NCCL backend) is:

```
source GPU --[before serialize]--> source CPU --[before broadcast]--> source GPU
--[broadcast]--> destination GPU --[before deserialize]--> destination CPU
--[deserialize]--> destination GPU --[copy]--> destination CPU
```

By setting `map_location=torch.device("cpu")` instead, the suffix becomes:

```
destination CPU --[deserialize]--> destination CPU --[copy]--> destination CPU
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61539

Test Plan:
I added a test `test_broadcast_object_list_map_location()` that checks, with `map_location` set to both CPU and GPU, that (1) tensors contained in broadcast objects are loaded onto the specified device and (2) the contents of those tensors are correct.

The existing `ZeroRedundancyOptimizer` tests pass:
```
gpurun4 python test/distributed/optim/test_zero_redundancy_optimizer.py
```
The existing `broadcast_object_list()` test passes:
```
touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" gpurun python test/distributed/test_distributed_fork.py -- TestDistBackendWithFork.test_broadcast_object_list
```

Reviewed By: zou3519

Differential Revision: D29701479

Pulled By: andwgu

fbshipit-source-id: c8d5f9057b32e5e9f40e8edc5b2cc25fb21414a9
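The win from deserializing directly to CPU is that the copy step which follows in `consolidate_state_dict()` then finds every tensor already on its target device. A minimal sketch of that copy step, assuming a hypothetical recursive helper (PyTorch's real helper may differ in name and detail):

```python
import torch

def recursive_copy_to_device(value, device):
    # Walk a (possibly nested) container of dicts/lists/tuples and move
    # every tensor onto `device`; non-tensor leaves pass through unchanged.
    # Illustrative stand-in for the post-broadcast copy, not PyTorch's code.
    if torch.is_tensor(value):
        return value.to(device)
    if isinstance(value, dict):
        return {k: recursive_copy_to_device(v, device) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(recursive_copy_to_device(v, device) for v in value)
    return value

# If the state dict was deserialized with map_location=torch.device("cpu"),
# each .to(device) below is a no-op: no extra GPU-to-CPU hop remains.
state = {"param_groups": [{"lr": 0.1}], "buf": torch.arange(3.0)}
cpu_state = recursive_copy_to_device(state, torch.device("cpu"))
```

Under `map_location=self._default_device` (a GPU), the same copy would instead trigger one GPU-to-CPU transfer per tensor, which is exactly the cost the commit avoids.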
Author
Andrew Gu