[FSDP] Fix skip-sharded-views + mixed precision (#105346)
This fixes https://github.com/pytorch/pytorch/issues/104504.
- When not using full-precision eval, the relevant fix is to force `_use_sharded_views()` calls if needed in `SUMMON_FULL_PARAMS` training state.
- When using full-precision in eval, the relevant fix is tracking what was the unsharded flat parameter from which the unsharded views were computed and using that instead of determining the unsharded flat parameter from the calling context via `_get_padded_unsharded_flat_param()`.
This also fixes https://github.com/pytorch/pytorch/issues/104770.
<details>
<summary> Print output showing parity </summary>
```
Key: 0
Model 1: [-1.5, 6.40625, -0.9453125, -0.3828125, 0.16015625, -1.5078125]
Model 2: [-1.5, 6.40625, -0.9453125, -0.3828125, 0.16015625, -1.5078125]
Key: 1
Model 1: [0.0157470703125, -0.8828125, 5.65625, 1.1328125, 0.275390625, 0.11181640625]
Model 2: [0.0157470703125, -0.8828125, 5.65625, 1.1328125, 0.275390625, 0.11181640625]
Key: 2
Model 1: [0.1689453125, -0.00567626953125, -0.09375, 7.34375, -0.18359375, -0.09521484375]
Model 2: [0.1689453125, -0.00567626953125, -0.09375, 7.34375, -0.18359375, -0.09521484375]
Key: 3
Model 1: [0.546875, -0.8984375, 0.228515625, 0.7578125, 6.0625, 0.435546875]
Model 2: [0.546875, -0.8984375, 0.228515625, 0.7578125, 6.0625, 0.435546875]
Key: 4
Model 1: [-0.66796875, -0.88671875, 0.30078125, 0.06494140625, 0.412109375, 6.9375]
Model 2: [-0.66796875, -0.88671875, 0.30078125, 0.06494140625, 0.412109375, 6.9375]
Key: 5
Model 1: [0.07763671875, 0.8671875, -0.43359375, 0.5703125, 0.76171875, -0.0089111328125]
Model 2: [0.07763671875, 0.8671875, -0.43359375, 0.5703125, 0.76171875, -0.0089111328125]
Key: 6
Model 1: [-0.283203125, -0.361328125, 0.474609375, 0.10205078125, 1.125, -0.0859375]
Model 2: [-0.283203125, -0.361328125, 0.474609375, 0.10205078125, 1.125, -0.0859375]
Key: 7
Model 1: [1.140625, 0.62890625, -0.07568359375, -1.0390625, -0.2578125, -0.053955078125]
Model 2: [1.140625, 0.62890625, -0.07568359375, -1.0390625, -0.2578125, -0.053955078125]
Key: 8
Model 1: [0.68359375, -1.09375, 0.59375, 1.0, -0.23828125, 0.578125]
Model 2: [0.68359375, -1.09375, 0.59375, 1.0, -0.23828125, 0.578125]
Key: 9
Model 1: [0.515625, 0.296875, -0.1826171875, -0.12890625, -0.51953125, -0.3359375]
Model 2: [0.515625, 0.296875, -0.1826171875, -0.12890625, -0.51953125, -0.3359375]
```
</details>
Follow-ups:
- I suspect that for `SHARD_GRAD_OP`, train forward -> eval forward when using full-precision in eval will not free the low-precision unsharded parameters from the train forward, resulting in 1.5x unsharded parameter memory.
Differential Revision: [D47527597](https://our.internmc.facebook.com/intern/diff/D47527597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105346
Approved by: https://github.com/fegin, https://github.com/rohan-varma