pytorch
4739d15a - Skip some nodes during discovery using sequence number (#52180)

Commit
3 years ago
Skip some nodes during discovery using sequence number (#52180) Summary: Fixes https://github.com/pytorch/pytorch/issues/12635 This change will help us speed up autograd's discovery algorithm in cases where we use `.grad` and we try to "unroll" the training loop. For example the example in the issue and also https://github.com/pytorch/pytorch/pull/52180#issuecomment-783400832 observe an unbounded multiple of speed-up. We do this by adding a new sequence_nr-type numbering: for each node, we maintain the length of the longest path from it to any leaf node. How does this help us speed up discovery (dfs)? Previously the bottleneck was that the dfs that computes which nodes need to be executed always explored every node. With this change, before we run dfs, we first compute the mininum seq_nr among all the nodes passed as the `inputs`. If let this be some number N, intuitively this means that dfs should stay at least N units away from any leaf node. So, if we find ourselves too close to any leaf node, we should stop our search early. Edit: After some discussion offline, the plan is: - make old sequence_nr a construct of the profiler. This means we can avoid accessing thread local state in cases where the profiler is disabled. Note that we cannot replace sequence_nr as-is because profiler's use-case requires that thread-id + sequence_nr can uniquely identify a given node in order for downstream users/programs to correlate nodes from backward and forward passes. This means we must maintain two sequence_nr's and that we have an extra field in Node. - In a future PR, we can potentially remove sequence_nr entirely from the profiler as well, but we avoid doing it now because we haven't measured, and its a larger effort because we'd have to mess around with the dispatcher and profiler Testing with this [code](https://gist.github.com/kyunghyuncho/5fb9991ce1233f909051854a84b7148e), we see that runtime no longer increases as we iterate. Before: ``` 100: Time taken: 0.47s, loss: 1.1e+06 200: Time taken: 0.064s, loss: 6.5e+05 300: Time taken: 0.088s, loss: 4.4e+05 400: Time taken: 0.1s, loss: 3.2e+05 500: Time taken: 0.12s, loss: 2.5e+05 600: Time taken: 0.15s, loss: 2e+05 700: Time taken: 0.18s, loss: 1.7e+05 800: Time taken: 0.2s, loss: 1.4e+05 900: Time taken: 0.22s, loss: 1.2e+05 1000: Time taken: 0.24s, loss: 1.1e+05 1100: Time taken: 0.27s, loss: 9.3e+04 1200: Time taken: 0.3s, loss: 8.3e+04 1300: Time taken: 0.34s, loss: 7.4e+04 1400: Time taken: 0.36s, loss: 6.7e+04 1500: Time taken: 0.38s, loss: 6.1e+04 1600: Time taken: 0.4s, loss: 5.6e+04 1700: Time taken: 0.42s, loss: 5.1e+04 1800: Time taken: 0.44s, loss: 4.7e+04 1900: Time taken: 0.47s, loss: 4.4e+04 2000: Time taken: 0.5s, loss: 4.1e+04 ``` After: ``` 100: Time taken: 0.49s, loss: 1.2e+06 200: Time taken: 0.031s, loss: 6.9e+05 300: Time taken: 0.031s, loss: 4.6e+05 400: Time taken: 0.031s, loss: 3.3e+05 500: Time taken: 0.031s, loss: 2.6e+05 600: Time taken: 0.031s, loss: 2.1e+05 700: Time taken: 0.031s, loss: 1.7e+05 800: Time taken: 0.031s, loss: 1.4e+05 900: Time taken: 0.031s, loss: 1.2e+05 1000: Time taken: 0.031s, loss: 1.1e+05 1100: Time taken: 0.031s, loss: 9.6e+04 1200: Time taken: 0.031s, loss: 8.6e+04 1300: Time taken: 0.031s, loss: 7.7e+04 1400: Time taken: 0.031s, loss: 7e+04 1500: Time taken: 0.031s, loss: 6.3e+04 1600: Time taken: 0.031s, loss: 5.8e+04 1700: Time taken: 0.031s, loss: 5.3e+04 1800: Time taken: 0.031s, loss: 4.9e+04 1900: Time taken: 0.031s, loss: 4.5e+04 2000: Time taken: 0.032s, loss: 4.2e+04 ``` Testing w/ small graph to check for regression: ``` import torch from torch.utils.benchmark import Timer setup=""" a = torch.rand((2, 2), requires_grad=True) b = torch.rand((2, 2), requires_grad=True) gradient = torch.ones(2, 2) """ stmt=""" torch.autograd.grad(a*b, [a, b], gradient) """ timer = Timer(stmt, setup) print(timer.timeit(10000)) print(timer.collect_callgrind(100)) ``` Result: there doesn't seem to be any significant regression ``` Time before: 12.74 us Time after: 13.12 us Instruction count before: All Noisy symbols removed Instructions: 8078960 8000882 Baseline: 4226 3838 Instruction count after: All Noisy symbols removed Instructions: 8091846 8017940 Baseline: 4336 3838 100 runs per measurement, 1 thread ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/52180 Reviewed By: gchanan, zhangguanheng66 Differential Revision: D26794387 Pulled By: soulitzer fbshipit-source-id: c00d387a29f151109c33dc6f1b56a8f275cdec58
Author
Parents
Loading