Set ThreadLocalState correctly in the autograd engine (#56174)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56174
`evaluate_function` does two things:
1. calls the autograd function (`call_function`)
2. accumulates gradients into buffers

Previously, ThreadLocalStateGuard covered only the `call_function` step.
However, it should cover all Tensor operations in `evaluate_function`,
so this PR moves it to do so.
One alternative would have been to move ThreadLocalStateGuard to here:
https://github.com/pytorch/pytorch/blob/71f9e99e293d0eff8da665b69543d044a6a4454d/torch/csrc/autograd/engine.cpp#L394
Unfortunately, that placement adds about 2% additional instructions
according to the instruction-count benchmark below. This is because
`evaluate_function` can return early before doing any real work:
https://github.com/pytorch/pytorch/blob/71f9e99e293d0eff8da665b69543d044a6a4454d/torch/csrc/autograd/engine.cpp#L732-L735
If that placement is preferred anyway, please let me know.
Test Plan:
- Run existing tests; it is hard to construct a targeted test case for this change.
Benchmark plan:
TL;DR: the instruction count decreases slightly (3514184 → 3513884, about 300 fewer instructions) after this PR.
```
import torch
from torch.utils.benchmark import Timer

timer = Timer(
    stmt="""\
torch::autograd::grad({y}, {x}, {}, /*retain_grad=*/true);""",
    setup="""\
auto x = torch::ones({}, torch::requires_grad());
auto y = x * 2;""",
    language="cpp")

stats = timer.collect_callgrind()
print(stats)
```
This gave the following:
```
Before:
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7f4b28ce6a90>
torch::autograd::grad({y}, {x}, {}, /*retain_grad=*/true);
setup:
  auto x = torch::ones({}, torch::requires_grad());
  auto y = x * 2;

                       All  Noisy symbols removed
  Instructions:    3514184              3514184
  Baseline:              0                    0
100 runs per measurement, 1 thread

After:
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7fdbc9d187d0>
torch::autograd::grad({y}, {x}, {}, /*retain_grad=*/true);
setup:
  auto x = torch::ones({}, torch::requires_grad());
  auto y = x * 2;

                       All  Noisy symbols removed
  Instructions:    3513884              3513884
  Baseline:              0                    0
100 runs per measurement, 1 thread
```
Reviewed By: albanD
Differential Revision: D27799283
Pulled By: zou3519
fbshipit-source-id: 0a8213824e08c04748d38e66604c73f395285d63