[FSDP] Use correct handle training state when prefetching (#98249)
This PR ensures that when prefetching a `FlatParamHandle.unshard()`, we temporarily set `FlatParamHandle._training_state` to the training state the handle would have if the `unshard()` were not prefetched. This matters because the `as_params` argument to `_use_unsharded_views()` depends on the handle's training state, so prefetching under the wrong state would produce the wrong kind of unsharded views.
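
Below is a minimal sketch of the idea, not FSDP's actual implementation: the class and method names mirror those mentioned above, but the bodies, the `HandleTrainingState` enum values, the `_prefetch_training_state` helper, and the mapping from training state to `as_params` are simplified assumptions for illustration only.

```python
from contextlib import contextmanager
from enum import Enum, auto


class HandleTrainingState(Enum):
    # Simplified stand-in for FSDP's handle training states (values assumed).
    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()


class FlatParamHandle:
    # Minimal stand-in keeping only the pieces relevant to this fix.
    def __init__(self) -> None:
        self._training_state = HandleTrainingState.IDLE

    def unshard(self) -> None:
        # In FSDP, unshard() all-gathers the flat parameter and then sets up
        # unsharded views; only the view step, which depends on the training
        # state, is shown here. The mapping below is illustrative only.
        self._use_unsharded_views(
            as_params=(self._training_state == HandleTrainingState.IDLE)
        )

    def _use_unsharded_views(self, as_params: bool) -> None:
        print(f"_use_unsharded_views(as_params={as_params})")


@contextmanager
def _prefetch_training_state(handle: FlatParamHandle, state: HandleTrainingState):
    # Temporarily set the handle's training state to what it would be if the
    # unshard() were not prefetched, then restore the previous state.
    prev = handle._training_state
    handle._training_state = state
    try:
        yield
    finally:
        handle._training_state = prev


# Usage: run the prefetched unshard() under the expected (non-prefetched) state.
handle = FlatParamHandle()
with _prefetch_training_state(handle, HandleTrainingState.FORWARD):
    handle.unshard()  # as_params now reflects the expected training state
```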
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98249
Approved by: https://github.com/rohan-varma