77df2ca9 - Special-case fsdp wrapped modules to be Unspecialized (#89330)

Special-case fsdp wrapped modules to be Unspecialized (#89330)

### Summary
Making dynamo treat the nn.Modules inside FSDP wrappers as 'Unspecialized' results in dynamo-produced graphs where nn.Module parameters are inputs to the graph rather than attributes of the outer graphmodule. This helps FSDP because it forces dynamo to pick up the latest copy of the parameters off the user's nn.Module (which FSDP mutates on every pre_forward), solving the ordering issue in backward.

### Details
Imagine this toy model:

```
import torch
import torch.nn as nn


class MyModule(torch.nn.Module):
    def __init__(self, a, b):
        super(MyModule, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(a, b),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net = nn.Sequential(
            *[MyModule(10, 10000)]
            + [MyModule(10000, 1000)]
            + [MyModule(1000, 5)]
        )

    def forward(self, x):
        return self.net(x)
```

FSDP is recursively wrapped around each `MyModule`, and the result is dynamo-compiled, with dynamo already configured to skip/graph-break in FSDP code. You'd expect to get 3 compiled AOT functions, corresponding to the contents of each `MyModule`, and to see FSDP's communication ops happen eagerly in between them. This almost happens (everything works out fine in forward), but in backward there is an ordering issue.

FSDP creates a flat buffer for all the parameters that are bucketed together, and then creates views into this buffer to replace the original parameters. On each forward iteration it creates a new view after 'filling' the flat buffer with data from an all-gather operation, to 'unshard' the parameters from remote devices. Dynamo traces the first such view and stores it in a compiled graphmodule.

During tracing, we see: (1) view created for the first MyModule, (2) first MyModule compiled, (3) the same for the rest of the layers. Then at runtime, we see: (A) view created for the first MyModule (and orphaned), (B) first compiled MyModule executed using the old view, and so on.

This is a problem because we want backward hooks to run right after each compiled backward, but autograd executes those hooks in an order mirroring their execution order during forward. Since we are forever using the views created during steps (1, 3, ..., N), which all happen before steps (A, B, ...), all the hooks end up running after all the compiled backwards.

An illustration of the problem: a torchviz graph showing the 2 possible orderings of autograd, and a profile showing the view-backward ops happening after all the compiled backwards and before all the backward hooks.

<img width="2069" alt="image" src="https://user-images.githubusercontent.com/4984825/202828002-32dbbd15-8fc3-4281-93e9-227ab5e32683.png">
<img width="2069" alt="image" src="https://user-images.githubusercontent.com/4984825/202828632-33e40729-9a7f-4e68-9ce1-571e3a8dd2dd.png">

The solution is to make dynamo not specialize on these nn.Modules. It is worth pointing out that this nn.Module specialization is de facto failing anyway: we modify `.parameters`, which bypasses dynamo's `__setattr__` monkeypatch that should have automatically kicked us out to Unspecialized and forced a recompile. After unspecializing, the new views (created during steps A, C, ...) are actually _used_ at runtime by the module, making their creation order interleaved, which makes autograd execute their backwards interleaved.
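To make the scenario concrete, here is a minimal sketch of how the wrapped-and-compiled toy model could be driven end to end. This is illustrative only and not code touched by this PR: it assumes a `torchrun` launch with one CUDA device per rank, wraps each `MyModule` in its own FSDP unit by hand instead of using an auto-wrap policy, and uses `torch.compile` with the `aot_eager` backend as one way to invoke dynamo.

```
# Illustrative sketch only. Assumes a distributed launch, e.g.:
#   torchrun --nproc_per_node=2 toy_fsdp_dynamo.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))


class MyModule(nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(a, b), nn.ReLU())

    def forward(self, x):
        return self.net(x)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Wrap each MyModule in its own FSDP unit so each gets its own flat
        # parameter and all-gather, mirroring recursive wrapping.
        self.net = nn.Sequential(
            FSDP(MyModule(10, 10000).cuda()),
            FSDP(MyModule(10000, 1000).cuda()),
            FSDP(MyModule(1000, 5).cuda()),
        )

    def forward(self, x):
        return self.net(x)


model = FSDP(ToyModel())
# Dynamo graph-breaks inside FSDP code, so this should produce one compiled
# region per MyModule, with FSDP's communication ops (all-gather,
# reduce-scatter) running eagerly in between.
compiled = torch.compile(model, backend="aot_eager")
out = compiled(torch.randn(20, 10, device="cuda"))
out.sum().backward()

dist.destroy_process_group()
```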
The new torchviz graph (this time with names added for the view tensors):

<img width="2043" alt="image" src="https://user-images.githubusercontent.com/4984825/202828480-d30005ba-0d20-45d8-b647-30b7ff5e91d3.png">

And a new profile showing the interleaving of compiled backwards and hooks, allowing the reduce-scatter to overlap:

<img width="2293" alt="image" src="https://user-images.githubusercontent.com/4984825/202828533-bb20a041-19b8-499c-b3cf-02808933df47.png">

@jansel @davidberard98 @aazzolini @mrshenli @awgu @ezyang @soumith @voznesenskym @anijain2305

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89330
Approved by: https://github.com/davidberard98