9f3c6b1b - Fix graph break in a common `func(self, *args)` pattern (Faster stable diffusion) (#100444)

Stable Diffusion has a pattern like this:

```python
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
    # The `Attention` class can call different attention processors / attention functions
    # here we simply pass along all tensors to the selected processor class
    # For standard processors that are defined here, `**cross_attention_kwargs` is empty
    return self.processor(
        self,
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        **cross_attention_kwargs,
    )
```

where `processor` is something like `AttnProcessor2_0`, which is callable but not an `nn.Module`. Handling this pattern without a graph break allows for a significant speedup in Stable Diffusion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100444
Approved by: https://github.com/anijain2305