[FSDP] Getting the parameter execution information using torch.fx (#80294)
This support allows one to obtain the parameter execution order at FSDP module construction time rather than at runtime. For models where this is possible, the runtime preparation step can be removed. This will be used in the non-recursive wrapping policy later on.
Note that this support is based on the assumption that the tracer provided by the user will be able to successfully trace the forward pass.
### Advantage of using `torch.fx` to get the parameter execution order rather than backward hooks
When using backward hooks, the parameter execution order is the reverse of the order in which parameter gradients become ready. One problem is that this does not tell us how many times each parameter is used inside the forward function. For example, consider the following forward function:
```python
def forward(self, x):
    z = self.relu(self.layer0(x))
    z = self.relu(self.layer2(z))
    z = self.relu(self.layer1(z))
    z = self.relu(self.layer0(z))
    return z
```
Based on the parameter gradient ready order, the parameter execution order derived for this example is `[layer0.weight, layer2.weight, layer1.weight]`. However, it does not tell us that `layer0` is called twice; the sketch below illustrates this limitation.
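The following is a minimal sketch (not part of this PR) of the backward-hook approach: each parameter's gradient hook fires once per backward pass, so repeated forward uses of the same module cannot be observed. The `Model` class here is a hypothetical module that mirrors the forward function above.
```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(4, 4)
        self.layer1 = nn.Linear(4, 4)
        self.layer2 = nn.Linear(4, 4)
        self.relu = nn.ReLU()

    def forward(self, x):
        z = self.relu(self.layer0(x))
        z = self.relu(self.layer2(z))
        z = self.relu(self.layer1(z))
        z = self.relu(self.layer0(z))
        return z

model = Model()
grad_ready_order = []
for name, param in model.named_parameters():
    if name.endswith("weight"):
        # Record the parameter name when its (accumulated) gradient is ready.
        param.register_hook(lambda grad, name=name: grad_ready_order.append(name))

model(torch.randn(2, 4)).sum().backward()
# Reversing the gradient-ready order approximates the forward execution order,
# but `layer0.weight` appears only once even though `layer0` ran twice.
print(list(reversed(grad_ready_order)))  # e.g. ['layer0.weight', 'layer2.weight', 'layer1.weight']
```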
Using `torch.fx`, we can get a more detailed parameter execution order: `[layer0.weight, layer2.weight, layer1.weight, layer0.weight]`. This enables scheduling algorithms that could be useful in multiple regimes. For example, since we know that `layer0` will be called twice, we can delay the resharding of `layer0.weight` until after its last use, trading extra memory for speed.
### Example of API usage
The execution information is recorded by calling `tracer.trace` inside the `_patch_tracer` context manager:
```python
tracer = torch.fx.Tracer()  # or an instance of a Tracer subclass
execution_info = _init_execution_info(model)
with _patch_tracer(
    tracer=tracer, root_module=model, execution_info=execution_info
):
    tracer.trace(model, concrete_args=...)
```
The execution information will be recorded in `execution_info.module_forward_order` and `execution_info.module_to_execution_infos`.
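Continuing the snippet above, a hedged sketch of inspecting the recorded information; the exact contents of these attributes are internal to FSDP and may change:
```python
# module_forward_order records modules in the order their forward() ran
# during tracing.
print(execution_info.module_forward_order)

# module_to_execution_infos maps each traced module to the execution
# information recorded for it.
print(execution_info.module_to_execution_infos)
```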
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80294
Approved by: https://github.com/mrshenli, https://github.com/zhaojuanmao