Fix graph break in a common `func(self, *args)` pattern (Faster stable diffusion) (#100444)
Stable Diffusion has a pattern like this:
```python
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
    # The `Attention` class can call different attention processors / attention functions
    # here we simply pass along all tensors to the selected processor class
    # For standard processors that are defined here, `**cross_attention_kwargs` is empty
    return self.processor(
        self,
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        **cross_attention_kwargs,
    )
```
Here `processor` is something like `AttnProcessor2_0`, which is callable but not an `nn.Module`. Previously, Dynamo graph-broke on this call pattern; handling it without a graph break allows for a significant speedup in Stable Diffusion.
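For reference, a minimal sketch of the dispatch pattern in plain Python (hypothetical names, no torch dependency): a module-like class delegates its `forward` to a plain callable "processor" object that is not an `nn.Module`, passing `self` along so the processor can read the owning module's attributes.

```python
class AttnProcessorSketch:
    """Callable but not an nn.Module; receives the owning module as `attn`."""

    def __call__(self, attn, hidden_states, scale=1.0, **kwargs):
        # Read an attribute off the owning module, as real processors do.
        return [h * attn.gain * scale for h in hidden_states]


class AttentionSketch:
    def __init__(self, gain):
        self.gain = gain
        # The processor is a plain callable object, swappable at runtime.
        self.processor = AttnProcessorSketch()

    def forward(self, hidden_states, **cross_attention_kwargs):
        # The `func(self, *args)` pattern from the PR: forward everything
        # to the selected processor, including `self`.
        return self.processor(self, hidden_states, **cross_attention_kwargs)


attn = AttentionSketch(gain=2.0)
out = attn.forward([1.0, 3.0], scale=0.5)
# out == [1.0, 3.0]
```

The point of the PR is that when such a `forward` is compiled with `torch.compile`, the call `self.processor(self, ...)` no longer forces a graph break even though the processor is not an `nn.Module`.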
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100444
Approved by: https://github.com/anijain2305