[dynamo] profiler.record_function on all dynamo_timed functions (#96495)
**Summary**: profiler.record_function inserts an event into the chrome trace generated by the pytorch profiler. This PR adds record_function everywhere that @dynamo_timed is annotated.
dynamo_timed and the CLI viewer torch._dynamo.utils.compile_times() are already useful on their own; but for identifying _when_ these get called, it's nice to be able to view in the profiler chrome trace.
Why not just turn on python stack traces in the profiler to get this information? Dynamo compilation is implemented in python and therefore produces a huge amount of events when it records compilation steps. The resulting trace files are often too large to load in chrome://tracing, and they take a long time to generate. Additionally, the stack traces are deep enough that they are often hard to read. This approach produces much more readable traces with lower overhead.
**Tests**:
- Added in test/dynamo/test_profiler.py. Verified in https://github.com/pytorch/pytorch/actions/runs/4559322864/jobs/8043307798?pr=96495 that the tests are actually running.
- Performance run with `ciflow/inductor-perf-compare` shows no noticeable change in compilation time or speedup numbers. Geomean speedup changes from 1.275 -> 1.277. Geomean compilation times change from 54.2s -> 53.8s. That's likely just due to noise. All individual benchmark numbers regressed by no more than 5% between the two runs; and we see improvements of around the same magnitude, suggesting this is, again, just noise. For meta employees, you can see the results in a google sheets here: https://docs.google.com/spreadsheets/d/1Ki69XvcgxcA3ZnqC5n_jav5KiD4u7Wojlad3VTnIdlk/edit?usp=sharing
**Example**:
Run this:
```python
import torch
def gn(x):
return x.sin().cos()
def fn(x, y):
return x.sin() * y.cos()
x, y = [torch.rand((2, 2), device='cuda') for _ in range(2)]
# just to clear out any lazy initialization
with torch.profiler.profile() as prof:
torch.compile(gn)(x)
with torch.profiler.profile() as prof:
torch.compile(fn)(x, y)
prof.export_chrome_trace("./dynamo_timed_profile.json")
```
and we can see that the resulting trace shows important dynamo steps, even when python tracing is turned off.
<img width="867" alt="Screenshot 2023-03-29 at 7 26 15 PM" src="https://user-images.githubusercontent.com/5067123/228712263-8ae67ab9-1a52-4765-a9c2-7c5cf0abe2f5.png">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96495
Approved by: https://github.com/ngimel, https://github.com/mlazos