pytorch
d2d11125 - Set ThreadLocalState correctly in the autograd engine (#56174)

Commit
3 years ago
Set ThreadLocalState correctly in the autograd engine (#56174) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56174 evaluate_function: 1. calls the autograd function (call_function) 2. accumulates gradients into buffers Previously, ThreadLocalStateGuard only covered part of `call_function`. However, it should cover all Tensor operations in `evaluate_function`, so this PR moves it to do so. One alternative would have been to move ThreadLocalStateGuard to here: https://github.com/pytorch/pytorch/blob/71f9e99e293d0eff8da665b69543d044a6a4454d/torch/csrc/autograd/engine.cpp#L394 Unfortunately that adds 2% additional instructions according to the instruction count benchmark in the next section. This is because `evaluate_function` does an early return: https://github.com/pytorch/pytorch/blob/71f9e99e293d0eff8da665b69543d044a6a4454d/torch/csrc/autograd/engine.cpp#L732-L735 If this is preferred, please let me know. Test Plan: - run existing tests. It's hard to actually come up with a test case for this. Benchmark plan: TL;DR: Instruction count decreases by a little after this PR. ``` import torch from torch.utils.benchmark import Timer timer = Timer( stmt="""\ torch::autograd::grad({y}, {x}, {}, /*retain_grad=*/true);""", setup="""\ auto x = torch::ones({}, torch::requires_grad()); auto y = x * 2;""", language="cpp") stats = timer.collect_callgrind() print(stats) ``` This gave the following: ``` Before: <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7f4b28ce6a90> torch::autograd::grad({y}, {x}, {}, /*retain_grad=*/true); setup: auto x = torch::ones({}, torch::requires_grad()); auto y = x * 2; All Noisy symbols removed Instructions: 3514184 3514184 Baseline: 0 0 100 runs per measurement, 1 thread After: <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7fdbc9d187d0> torch::autograd::grad({y}, {x}, {}, /*retain_grad=*/true); setup: auto x = torch::ones({}, torch::requires_grad()); auto y = x * 2; All Noisy symbols removed Instructions: 3513884 3513884 Baseline: 0 0 100 runs per measurement, 1 thread ``` Reviewed By: albanD Differential Revision: D27799283 Pulled By: zou3519 fbshipit-source-id: 0a8213824e08c04748d38e66604c73f395285d63
Author
Parents
Loading