[FSDP] Reduce CPU overhead (#96958)
I experimented with 200 `nn.Linear`s with `bias=True` (400 `nn.Parameter`s in total), all wrapped into a single FSDP instance, with world size 2.
**`unshard()` -> `_use_unsharded_views()`**
- (From previous PR) unsafe `setattr`: 6.112 ms -> 4.268 ms
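The "unsafe" `setattr` avoids `nn.Module.__setattr__`, which scans `__dict__`, `_buffers`, `_modules`, and `_parameters` on every assignment. A minimal sketch of the idea (the helper name here is illustrative, not the FSDP-internal one):

```python
import torch
import torch.nn as nn

def unsafe_setattr_param(module: nn.Module, name: str, param: nn.Parameter) -> None:
    # Assumes `name` is already a registered parameter on `module`, so the
    # bookkeeping in nn.Module.__setattr__ is redundant. Write the two slots
    # directly: the registry used by named_parameters() and the instance dict
    # used by attribute lookup.
    module._parameters[name] = param
    object.__setattr__(module, name, param)

# Usage: swap in an "unsharded" parameter without the __setattr__ overhead.
linear = nn.Linear(4, 4)
unsharded = nn.Parameter(torch.ones(4, 4))
unsafe_setattr_param(linear, "weight", unsharded)
```

The precondition matters: skipping `nn.Module.__setattr__` is only safe because FSDP knows the attribute already exists as a parameter, so no registration logic is being silently dropped.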
**`pre_unshard()` -> `_writeback_orig_params()`**
- Factor out `flat_param` and `flat_param_grad` data pointers: ~1.8 ms -> 1.071 ms
- Now dominated by calling `_typed_storage()` on each original parameter and its gradient
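The writeback check asks whether each original parameter still aliases the flat parameter's storage; the win comes from computing the flat parameter's data pointer once, outside the per-parameter loop. A sketch under that assumption (helper name hypothetical; the real check lives in `_writeback_orig_params`):

```python
import torch

def shares_flat_storage(tensor: torch.Tensor, flat_data_ptr: int) -> bool:
    # A view into the flat parameter shares its storage, so the storage data
    # pointers match; a tensor the user swapped in lives in its own storage.
    return tensor.untyped_storage().data_ptr() == flat_data_ptr

flat_param = torch.empty(8)
orig_params = [flat_param[0:4], flat_param[4:8]]  # views into flat_param

# Hoist the flat parameter's data pointer out of the loop instead of
# recomputing it for every (param, grad) pair.
flat_ptr = flat_param.untyped_storage().data_ptr()
needs_writeback = [not shares_flat_storage(p, flat_ptr) for p in orig_params]
```

With 400 parameters plus gradients, that removes hundreds of redundant storage lookups per iteration, matching the remaining cost being the per-tensor `_typed_storage()` calls noted above.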
**`reshard()` -> `_use_sharded_views()`**
- Factor out `torch.empty(0, ...)`: ~4.6 - 4.7 ms -> ~2.7 - 2.8 ms
- Now dominated by `aten::slice()` and (unsafe) `setattr`, which are required
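For parameters that fall outside the local shard, the sharded "view" is just a size-0 tensor, so one placeholder can be allocated before the loop and reused, rather than calling `torch.empty(0, ...)` once per parameter. A sketch with assumed shard bounds:

```python
import torch

flat_param = torch.empty(6)
# Assume this rank's shard covers elements [0, 4); the second original
# parameter, at [4, 6), falls entirely outside the local shard.
param_ranges = [(0, 4), (4, 6)]
shard_end = 4

# Allocate the size-0 placeholder once, outside the loop.
empty_placeholder = torch.empty(0, dtype=flat_param.dtype, device=flat_param.device)

sharded_views = []
for start, end in param_ranges:
    if start >= shard_end:
        # Parameter not present locally: reuse the shared placeholder.
        sharded_views.append(empty_placeholder)
    else:
        # In-shard parameter: slice the flat parameter (the aten::slice cost).
        sharded_views.append(flat_param[start:min(end, shard_end)])
```

The remaining per-parameter work, the slice and the attribute swap, is inherent to materializing the views, which is why those now dominate.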
I removed some `assert` calls that existed only to satisfy mypy or that duplicated an error the subsequent call would raise anyway. Their overhead is negligible, but I think it is still worth removing them to avoid the type check; we need to address type checking more holistically anyway.
---
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96958
Approved by: https://github.com/rohan-varma