Enhance Composable FSDP cast forward input tests (#100349)
The fix for https://github.com/pytorch/pytorch/pull/99545 (https://github.com/pytorch/pytorch/pull/99546) explicitly required users to set `cast_forward_inputs=True` if they wanted to avoid hitting #99545 while using an FSDP root module with no direct parameters.
After further consideration, [the team believes](https://github.com/pytorch/pytorch/pull/99546#discussion_r1180898687) it is sufficiently common for the default `cast_forward_inputs=False` to be used with an FSDP root module that has no direct parameters that a solution to #99545 accommodating this use case is desired.
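For context, here is a minimal sketch of the configuration in question: a parameterless root whose submodule uses a low-precision `MixedPrecision` policy while the root relies on the default `cast_forward_inputs=False`. The `Root` module, dtypes, and wrapping choices are illustrative rather than taken from the actual tests, and I use the `FullyShardedDataParallel` wrapper API for brevity even though the tests exercise the composable `fully_shard` API:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Assumes torch.distributed has already been initialized
# (e.g., via dist.init_process_group).

class Root(nn.Module):
    # All parameters live in the submodule; the root owns none directly.
    def __init__(self):
        super().__init__()
        self.child = nn.Linear(8, 8)

    def forward(self, x):
        return self.child(x)

model = Root()
# Submodule policy: cast this FSDP instance's forward inputs to fp16.
model.child = FSDP(
    model.child,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16, cast_forward_inputs=True
    ),
)
# The root wrapping relies on the defaults (cast_forward_inputs=False,
# cast_root_forward_inputs=True); with a parameterless root, this is the
# configuration that motivated the broader fix discussed above.
model = FSDP(model)
```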
This PR builds on @zhaojuanmao's https://github.com/pytorch/pytorch/pull/100290 (nice!) to extend the FSDP cast-forward-inputs testing to a broader range of scenarios and to validate `model.eval()` behavior in addition to training mode. (I unfortunately don't have the permissions required to use ghstack directly, but I'm happy to rebase this PR however the team desires once #100290 lands.)
Currently, the evaluation-mode testing is commented out while the team decides on the best approach to the broader solution to https://github.com/pytorch/pytorch/pull/99545. Once an implementation is chosen, the evaluation-mode validation function in the new tests added in this PR can be uncommented and should continue to pass. I also include one potential evaluation-mode solution as a suggestion in this PR, but I leave the existing code unchanged since I know the team intends to consider a range of solutions this week.
Test notes:
1. The 8 tests added here are a superset of the current `test_float16_on_one_submodule` tests, validating every combination of the following configurations: (`cast_root_forward_inputs_submodule` = True/False, `cast_forward_inputs_submodule` = True/False, `use_root_no_params` = True/False) across both training and evaluation modes.
2. The `float16_on_one_submodule` model configuration is currently tested only in the FSDP-root-with-parameters scenarios (as was previously the case), but this test can easily be extended to cover the root-with-no-parameters scenarios as well if the team thinks the additional test resource usage is justified.
3. Since this test amortizes the cost of test setup across the aforementioned range of scenarios, a loop-based implementation of the `dtype` validation (below) would have been undesirably complex IMHO[^1]:
```python
############### Logical equivalent of current test result matrix ############
if self.cast_root_forward_inputs_submodule or self.cast_forward_inputs_submodule:
    self.assertEqual(self.forward_inputs[self.c2].dtype, torch.float16)
    if use_root_no_params:
        if self.cast_root_forward_inputs_submodule:
            self.assertEqual(self.forward_inputs[self.model].dtype, torch.float16)
        else:
            self.assertEqual(self.forward_inputs[self.model].dtype, torch.float32)
        self.assertEqual(self.forward_inputs[self.c1].dtype, torch.float16)
    else:
        self.assertEqual(self.forward_inputs[self.c1].dtype, torch.float32)
else:
    self.assertEqual(self.forward_inputs[self.model].dtype, torch.float32)
    self.assertEqual(self.forward_inputs[self.c1].dtype, torch.float32)
    if not use_root_no_params:
        # this input will only exist in the root-with-params case until the eval fix is applied
        self.assertEqual(self.forward_inputs[self.c2].dtype, torch.float32)
```
so I implemented the validation function as an expected-result lookup, which has the added benefit of explicitly naming the failed subtest when a `dtype` assertion fails, e.g.:
```python
AssertionError: None mismatch: torch.float32 is not None
Subtest `no_cast_root_no_cast_child_no_root_params` failed.
```
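For illustration, here is a minimal sketch of what such a lookup-driven validation could look like. The table contents, helper names (`subtest_name`, `check_dtype`, `EXPECTED_INPUT_DTYPE`), and the "float16 if any cast flag is set" rule are hypothetical stand-ins, not the actual test code:

```python
from itertools import product

import torch

def subtest_name(cast_root: bool, cast_child: bool, root_no_params: bool) -> str:
    # Derive a human-readable lookup key from the subtest's flag combination.
    return "_".join([
        "cast_root" if cast_root else "no_cast_root",
        "cast_child" if cast_child else "no_cast_child",
        "no_root_params" if root_no_params else "root_params",
    ])

# Hypothetical expected-result table (illustrative dtypes, not the actual
# test matrix): one expected input dtype per subtest name, covering the
# full cross product of the three flags.
EXPECTED_INPUT_DTYPE = {
    subtest_name(cast_root, cast_child, root_no_params):
        torch.float16 if (cast_root or cast_child) else torch.float32
    for cast_root, cast_child, root_no_params in product([True, False], repeat=3)
}

def check_dtype(name: str, actual: torch.dtype) -> None:
    # In the real tests this would be self.assertEqual(..., msg=...);
    # a plain assert keeps the sketch framework-free.
    assert actual == EXPECTED_INPUT_DTYPE[name], f"Subtest `{name}` failed."

check_dtype(subtest_name(True, False, True), torch.float16)  # passes
```

The key doubles as the subtest label surfaced in the assertion message, which is what makes failures like the one shown above immediately attributable to a specific configuration.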
The potential solution to https://github.com/pytorch/pytorch/pull/99545 that I added as a suggestion in the file conversation passes this test set, but since there are many ways it could be resolved, I'll assume that change will be tackled in a separate PR unless the team wants to include it in this one.
As mentioned, I've based this PR on https://github.com/pytorch/pytorch/pull/100290, so I'm happy either to wait for that to land first or to rebase this PR however the team wants.
[^1]: Batching the scenarios into separate tests is also possible, of course, but would involve unnecessary test setup overhead; I'm happy to switch to that approach if the team prefers it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100349
Approved by: https://github.com/awgu