[FSDP] Fix "use-after-free" in reshard logic (#94859)
**Overview**
This PR switches the order of freeing the unsharded `FlatParameter` (`self._free_unsharded_flat_param()`) and switching to use the sharded `FlatParameter` (`self._use_sharded_flat_param()`). This prevents "use-after-free"-type bugs where, for `param.data = new_data`, `param` has its metadata intact but its storage freed, causing an illegal memory access for any instrumentation that depends on its storage. (Here, `param` is an original parameter, and `new_data` is either a view into the sharded `FlatParameter` or `torch.empty(0)`, depending on the sharding and rank.)
**Details**
To see why simply switching the order of the two calls is safe, let us examine the calls themselves:
https://github.com/pytorch/pytorch/blob/652457b1b738f710679b414fe4626d08c9a9e0db/torch/distributed/fsdp/flat_param.py#L1312-L1339
https://github.com/pytorch/pytorch/blob/652457b1b738f710679b414fe4626d08c9a9e0db/torch/distributed/fsdp/flat_param.py#L1298-L1310
- `_free_unsharded_flat_param()` does not make any assumption that `self.flat_param`'s data is the sharded `FlatParameter` (i.e. `_local_shard`).
- The sharded `FlatParameter` (i.e. `_local_shard`) is always present in memory, which means that FSDP can use sharded views at any time, including before freeing the unsharded data.
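The two properties above can be illustrated with a minimal pure-Python model of the ordering (hypothetical class and method names mirroring the PR's terminology; "storage" is modeled as a list, and freeing it models resizing a tensor's storage to zero while its metadata survives — this is a sketch, not the actual FSDP implementation):

```python
class FlatParamHandle:
    """Toy model of FSDP's reshard ordering (hypothetical, for illustration)."""

    def __init__(self, sharded, unsharded):
        self._local_shard = sharded       # always resident in memory
        self._unsharded = unsharded
        self.flat_param_data = unsharded  # currently pointing at unsharded data

    def _free_unsharded_flat_param(self):
        # Frees the unsharded storage; makes no assumption about which
        # data `self.flat_param_data` currently points to.
        self._unsharded.clear()

    def _use_sharded_flat_param(self):
        # Switches to sharded views; safe at any time because the sharded
        # FlatParameter (`_local_shard`) is always present in memory.
        self.flat_param_data = self._local_shard

    def reshard_fixed(self):
        # Order after this PR: switch views first, then free.
        self._use_sharded_flat_param()
        self._free_unsharded_flat_param()
        return list(self.flat_param_data)

    def reshard_buggy(self):
        # Old order: free first. Any instrumentation that reads
        # `flat_param_data` in the window between the two calls sees
        # metadata backed by freed storage (the "use-after-free").
        self._free_unsharded_flat_param()
        stale = list(self.flat_param_data)  # models the unsafe read
        self._use_sharded_flat_param()
        return stale


h = FlatParamHandle(sharded=[1.0, 2.0], unsharded=[1.0, 2.0, 3.0, 4.0])
print(h.reshard_fixed())   # → [1.0, 2.0] (sharded view, intact)

h2 = FlatParamHandle(sharded=[1.0, 2.0], unsharded=[1.0, 2.0, 3.0, 4.0])
print(h2.reshard_buggy())  # → [] (storage freed while still referenced)
```

Because `_local_shard` is always in memory and `_free_unsharded_flat_param()` does not read the current view, the two calls commute safely, and the PR's ordering closes the unsafe window.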
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94859
Approved by: https://github.com/zhaojuanmao, https://github.com/fegin