8da776e3 - [FSDP] Fix "use-after-free" in reshard logic (#94859)

**Overview**

This PR switches the order of freeing the unsharded `FlatParameter` (`self._free_unsharded_flat_param()`) and switching to use the sharded `FlatParameter` (`self._use_sharded_flat_param()`). This prevents "use-after-free"-type bugs where, for `param.data = new_data`, `param` has its metadata intact but not its storage, causing an illegal memory access for any instrumentation that depends on that storage. (`param` is an original parameter, and `new_data` is either a view into the sharded `FlatParameter` or `torch.empty(0)`, depending on the sharding and rank.)

**Details**

To see why simply switching the order of the two calls is safe, let us examine the calls themselves:

https://github.com/pytorch/pytorch/blob/652457b1b738f710679b414fe4626d08c9a9e0db/torch/distributed/fsdp/flat_param.py#L1312-L1339
https://github.com/pytorch/pytorch/blob/652457b1b738f710679b414fe4626d08c9a9e0db/torch/distributed/fsdp/flat_param.py#L1298-L1310

- `_free_unsharded_flat_param()` does not make any assumption that `self.flat_param`'s data is the sharded `FlatParameter` (i.e., `_local_shard`).
- The sharded `FlatParameter` (i.e., `_local_shard`) is always present in memory, so FSDP can switch to the sharded views at any time, including before freeing the unsharded data.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94859
Approved by: https://github.com/zhaojuanmao, https://github.com/fegin
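To make the fixed ordering concrete, here is a minimal, hypothetical sketch (not the actual `flat_param.py` code): the function name, its arguments, and the in-place `untyped_storage().resize_(0)` free are illustrative stand-ins for what `_use_sharded_flat_param()` and `_free_unsharded_flat_param()` do internally.

```python
import torch

def reshard_sketch(param: torch.nn.Parameter,
                   unsharded_flat_param: torch.Tensor,
                   local_shard: torch.Tensor) -> None:
    """Illustrative sketch of the fixed ordering (not the real FSDP code).

    `param` currently views `unsharded_flat_param`; `local_shard` stands in
    for the always-resident sharded `FlatParameter` (`_local_shard`).
    """
    # 1) Switch the original parameter to a view into the sharded data
    #    (roughly what _use_sharded_flat_param() does per original parameter;
    #    on some ranks the new data is torch.empty(0) instead of a view).
    sharded_view = local_shard[: param.numel()].view(param.shape)
    param.data = sharded_view

    # 2) Only then free the unsharded storage in place
    #    (roughly what _free_unsharded_flat_param() does).
    unsharded_flat_param.untyped_storage().resize_(0)

    # With the old order (free first, switch second), `param` would still
    # point at the freed unsharded storage between the two steps: its shape
    # and dtype metadata remain intact, but any instrumentation reading its
    # storage would perform an illegal memory access.
```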
Author: Andrew Gu