c3457550 - [FSDP] Fix `_mp_shard` `record_stream()` (#91096)

IIUC, I dropped a needed `record_stream` call in https://github.com/pytorch/pytorch/pull/83665. I think this was because my original version of that PR retired the pre-unshard stream, but after some quantitative investigation, I brought the stream back.

- We allocate the `_mp_shard` in the pre-unshard stream.
  https://github.com/pytorch/pytorch/blob/731f417f60bfd5bb8d2ec756c23c0e6624ea3351/torch/distributed/fsdp/_runtime_utils.py#L260-L263
- For sharded strategies, we consume the `_mp_shard` only in the unshard stream (for the all-gather).
  https://github.com/pytorch/pytorch/blob/731f417f60bfd5bb8d2ec756c23c0e6624ea3351/torch/distributed/fsdp/_runtime_utils.py#L270-L273
  https://github.com/pytorch/pytorch/blob/731f417f60bfd5bb8d2ec756c23c0e6624ea3351/torch/distributed/fsdp/flat_param.py#L1005-L1006
- For `NO_SHARD`, we consume the `_mp_shard` in the unshard stream (for views) and in the default stream (for computation).
  https://github.com/pytorch/pytorch/blob/731f417f60bfd5bb8d2ec756c23c0e6624ea3351/torch/distributed/fsdp/_runtime_utils.py#L304
  https://github.com/pytorch/pytorch/blob/731f417f60bfd5bb8d2ec756c23c0e6624ea3351/torch/distributed/fsdp/flat_param.py#L1256-L1261
- We must call `record_stream(_mp_shard, current_stream)` when freeing so that the caching allocator knows about the usage in the current stream (see the sketch after this message).
  - For sharded strategies, the free happens in `post_unshard()`, which runs in the unshard stream.
  - For `NO_SHARD`, the free happens in `post_reshard()`, which runs in the default stream.
  - Conveniently, for both cases the current stream is the correct stream to synchronize. For `NO_SHARD`, the default stream waits for the unshard stream, so recording only in the default stream should suffice.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91096
Approved by: https://github.com/rohan-varma
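A minimal sketch of the stream-semantics issue the fix addresses, written against PyTorch's public CUDA stream API rather than FSDP's internal helpers. The stream and tensor names (`pre_unshard_stream`, `unshard_stream`, `mp_shard`) only mirror the terminology above and are not the actual FSDP objects; this is an illustration of the `record_stream` pattern, not the code changed by the PR.

```python
# Sketch: a tensor allocated in one stream and consumed in another must be
# recorded on the consuming stream before its last reference is dropped,
# or the caching allocator may reuse the memory too early.
import torch

def record_stream_sketch():
    if not torch.cuda.is_available():
        return None

    pre_unshard_stream = torch.cuda.Stream()  # stand-in: where `_mp_shard` is allocated
    unshard_stream = torch.cuda.Stream()      # stand-in: where it is consumed (all-gather / views)

    # 1. Allocate the low-precision shard in the pre-unshard stream.
    #    The caching allocator associates the block with this stream.
    with torch.cuda.stream(pre_unshard_stream):
        mp_shard = torch.empty(1 << 20, dtype=torch.float16, device="cuda")

    # 2. Consume it in the unshard stream, which first waits on the allocation.
    unshard_stream.wait_stream(pre_unshard_stream)
    with torch.cuda.stream(unshard_stream):
        gathered = mp_shard * 2  # stand-in for the all-gather / view usage

        # 3. Before freeing, tell the allocator that the current (consuming)
        #    stream used this memory. Otherwise the allocator only knows about
        #    the allocation stream and could hand the block to a new allocation
        #    while this stream still has pending kernels reading it.
        mp_shard.record_stream(torch.cuda.current_stream())
        del mp_shard  # block stays reserved until the recorded stream's work completes

    torch.cuda.synchronize()
    return gathered
```

In the PR's terms, the free site is `post_unshard()` (unshard stream) for sharded strategies and `post_reshard()` (default stream) for `NO_SHARD`, so recording on the current stream at the free site covers both cases.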