[FSDP] Fix `_mp_shard` `record_stream()` (#91096)
IIUC, I dropped a needed `record_stream` call in https://github.com/pytorch/pytorch/pull/83665. I think this was because my original version of that PR retired the pre-unshard stream; after some quantitative investigation, I brought the stream back but did not restore the `record_stream` call.
- We allocate the `_mp_shard` in the pre-unshard stream.
https://github.com/pytorch/pytorch/blob/731f417f60bfd5bb8d2ec756c23c0e6624ea3351/torch/distributed/fsdp/_runtime_utils.py#L260-L263
- For sharded strategies, we consume the `_mp_shard` only in the unshard stream (for all-gather).
https://github.com/pytorch/pytorch/blob/731f417f60bfd5bb8d2ec756c23c0e6624ea3351/torch/distributed/fsdp/_runtime_utils.py#L270-L273
https://github.com/pytorch/pytorch/blob/731f417f60bfd5bb8d2ec756c23c0e6624ea3351/torch/distributed/fsdp/flat_param.py#L1005-L1006
- For `NO_SHARD`, we consume the `_mp_shard` in the unshard stream (for views) and in the default stream (for computation).
https://github.com/pytorch/pytorch/blob/731f417f60bfd5bb8d2ec756c23c0e6624ea3351/torch/distributed/fsdp/_runtime_utils.py#L304
https://github.com/pytorch/pytorch/blob/731f417f60bfd5bb8d2ec756c23c0e6624ea3351/torch/distributed/fsdp/flat_param.py#L1256-L1261
- We must call `record_stream(_mp_shard, current_stream)` when freeing so that the allocator knows about the usage in the current stream.
- For sharded strategies, the free happens in `post_unshard()`, which runs in the unshard stream.
- For `NO_SHARD`, the free happens in `post_reshard()`, which runs in the default stream.
- Conveniently, in both cases the current stream is the correct stream to record. For `NO_SHARD`, the default stream waits for the unshard stream, so recording only the default stream should suffice (see the sketch below).
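
For reference, here is a minimal standalone sketch of the allocate-in-one-stream / consume-in-another / record-on-free pattern that this fix targets. It uses generic `torch.cuda` streams; `pre_unshard_stream`, `unshard_stream`, and `mp_shard` are hypothetical stand-ins for FSDP's internals, not the actual FSDP objects.

```python
import torch

assert torch.cuda.is_available()

pre_unshard_stream = torch.cuda.Stream()  # stand-in for FSDP's pre-unshard stream
unshard_stream = torch.cuda.Stream()      # stand-in for FSDP's unshard stream

# (1) Allocate the low-precision shard in the pre-unshard stream; the caching
# allocator associates the block with the stream that is current at allocation.
with torch.cuda.stream(pre_unshard_stream):
    mp_shard = torch.empty(1 << 20, device="cuda", dtype=torch.float16)

# (2) Consume it in the unshard stream after ordering it behind the allocation.
unshard_stream.wait_stream(pre_unshard_stream)
with torch.cuda.stream(unshard_stream):
    gathered = mp_shard * 1.0  # placeholder for the real all-gather usage

# (3) Free it in the consumer stream (mirroring `post_unshard()` for sharded
# strategies): recording the current stream tells the allocator not to reuse
# the block until this stream's pending work on it has finished.
with torch.cuda.stream(unshard_stream):
    mp_shard.record_stream(torch.cuda.current_stream())
    mp_shard = None  # block returns to the allocator, held until `unshard_stream` catches up
```

Without step (3), the allocator only knows about the allocation stream, so the block could be handed out again while the unshard stream's kernels are still reading it.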
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91096
Approved by: https://github.com/rohan-varma