Refactor autograd discovery code (#52057)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/34067 by using https://github.com/pytorch/pytorch/issues/34426, authored by hczhu.
In addition to removing the unnecessary any(), this PR also:
- Get rid of the outer loop since graph_root also needs to be checked
- Update the pseudocode description so it matches what the code does
- Add some comments explaining the difference between assigning `info.needed_` and `info.captures_` in terms of how that affects discovery
- [edit: another benefit is that exec_info entries are no longer created for all reachable nodes]
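For readers unfamiliar with the discovery pass, the idea in the bullets above can be sketched in Python. This is only an illustrative model under stated assumptions, not the engine code: `Node`, `next_nodes`, and `discover_needed` are hypothetical names, the `needed` set stands in for entries with `info.needed_` set, and `info.captures_` is not modeled. A node is marked needed if it is a requested output or if any node reachable from it is needed, and entries are only recorded for needed nodes.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps identity-based hashing so nodes can live in sets
class Node:
    """Hypothetical stand-in for an autograd graph node."""
    name: str
    next_nodes: list = field(default_factory=list)


def discover_needed(graph_root, outputs):
    """Return the set of nodes that must run to reach any node in `outputs`.

    Assumes the graph is a DAG. Uses an explicit stack instead of recursion,
    starts directly from graph_root (no separate outer loop), and only stores
    entries for nodes that are actually needed.
    """
    needed = set(outputs)
    seen = {graph_root}
    stack = [(graph_root, iter(graph_root.next_nodes))]
    while stack:
        node, children = stack[-1]
        child = next(children, None)
        if child is None:
            # All children of `node` are processed: propagate to its parent.
            stack.pop()
            if stack and node in needed:
                needed.add(stack[-1][0])
            continue
        if child in seen:
            # Already fully explored, so its status is final.
            if child in needed:
                needed.add(node)
            continue
        seen.add(child)
        stack.append((child, iter(child.next_nodes)))
    return needed


# Usage: only the branch leading to the requested output is marked.
acc_a, acc_b, acc_c = Node("acc_a"), Node("acc_b"), Node("acc_c")
mul = Node("mul", [acc_a, acc_b])
other = Node("other", [acc_c])
root = Node("root", [mul, other])
print(sorted(n.name for n in discover_needed(root, [acc_a])))
```

In this sketch, nodes that are reachable but not needed (here `other`, `acc_b`, and `acc_c`) end up in `seen` only and never get an entry in `needed`, which mirrors the benefit noted in the last bullet.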
This PR is stacked on top of https://github.com/pytorch/pytorch/issues/51940, so once that lands, rebasing on top of master should get rid of the extra commits and changes.
I'm not sure if this change will bring a lot of performance gains, but the main benefit is that the code is easier to read.
Trivial graph:
```
torch.autograd.grad(a*b, [a, b], gradient)
setup:
a = torch.rand((2, 2), requires_grad=True)
b = torch.rand((2, 2), requires_grad=True)
gradient = torch.ones(2, 2)
Time before:
15.45 us
Time after:
14.33 us
1 measurement, 10000 runs, 1 thread
Instructions after:
All Noisy symbols removed
Instructions: 8271213 8193169
Baseline: 4244 3838
Instructions before:
All Noisy symbols removed
Instructions: 8142843 8054463
Baseline: 4280 3838
100 runs per measurement, 1 thread
```
Small graph:
```
torch.autograd.grad((b*a.exp()+a*b.exp()).sum(), (a, b))
setup:
a = torch.rand((2, 2), requires_grad=True)
b = torch.rand((2, 2), requires_grad=True)
Time before:
52.25 us
Time after:
50.80 us
1 measurement, 10000 runs, 1 thread
Instruction count before:
All Noisy symbols removed
Instructions: 25601257 25518229
Baseline: 4228 3838
Instruction count after:
All Noisy symbols removed
Instructions: 25606533 25522797
Baseline: 4228
100 runs per measurement, 1 thread
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52057
Reviewed By: ngimel
Differential Revision: D26432207
Pulled By: soulitzer
fbshipit-source-id: beef68344d66e9e286378e31e3311ba43c25c749