Add flop counter utility (#95751)
Overall, an example usage. Note that this *also* captures backwards FLOPs.
```
import torchvision.models as models
import torch
from torch.utils.flop_counter import FlopCounterMode
inp = torch.randn(1, 3, 224, 224, device='cpu')
mod = models.resnet18()
flop_counter = FlopCounterMode(mod, depth=1)
with flop_counter:
mod(inp).sum().backward()
```
<img width="326" alt="image" src="https://user-images.githubusercontent.com/6355099/222023068-3491e405-f195-4e11-b679-36b19a1380c7.png">
You can control the depth of the module hierarchy with the `depth` attribute (which defaults to 2). For example, if I don't limit it, this is what it outputs.
<img width="366" alt="image" src="https://user-images.githubusercontent.com/6355099/222023306-3d880bb6-f534-4f98-bf10-83c4353acefc.png">
## Other APIs
FlopCounterMode(custom_mapping=...): Allows for custom flop counting functions
FlopCounterMode.get_table(depth=...): Explicitly get the table as a string
FlopCounterMode.flop_counts: Contains the flop information as a Dict[hierarchy: str, Dict[Op, int]]
FlopCounterMode.register_hierarchy(f, name): Allows you to register additional "hierarchies" for a function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95751
Approved by: https://github.com/ngimel, https://github.com/albanD