pytorch
f02e93b5 - jacrev : Support chunked computation (#89376)

jacrev : Support chunked computation (#89376)

Ref: https://github.com/pytorch/functorch/issues/680

We introduce a kwarg `chunk_size` in `jacrev` to control whether the Jacobian computation should be chunked; if so, `chunk_size` dictates the maximum size of the chunks used.

We try two approaches:

* Stacked approach: append the intermediate computations to a list, then stack the results.
* Pre-allocation approach: pre-allocate a zeros tensor and copy the chunked computations into it.

For the memory benchmark, see https://github.com/pytorch/pytorch/pull/89376#issuecomment-1348479098

Benchmark CPU: performs better with more chunks / smaller `chunk_size`. NOTE: There seems to be a lot of noise for shape `(64, 64)`.

<details>

```
[------------------------------------ jacrev : device cpu : chunks 2 ------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -----------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 2080      |            76.2               |        50.9          |          80.1
      (128, 128) : chunk_size 8256    |          1172.8               |       783.3          |        1225.5
      (128, 144) : chunk_size 9288    |          1475.1               |       990.4          |        1548.3
      (144, 144) : chunk_size 10440   |          1871.3               |      1254.4          |        1971.2

Times are in milliseconds (ms).

[------------------------------------ jacrev : device cpu : chunks 3 ------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -----------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1386      |            39.9               |        25.8          |          58.8
      (128, 128) : chunk_size 5504    |          1182.6               |       782.2          |        1229.7
      (128, 144) : chunk_size 6192    |          1483.6               |       995.4          |        1550.6
      (144, 144) : chunk_size 6960    |          1879.1               |      1257.7          |        1960.5

Times are in milliseconds (ms).

[------------------------------------ jacrev : device cpu : chunks 4 ------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -----------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1040      |            41.7               |        50.6          |          29.1
      (128, 128) : chunk_size 4128    |          1171.6               |       782.3          |        1226.7
      (128, 144) : chunk_size 4644    |          1482.2               |       994.6          |        1550.9
      (144, 144) : chunk_size 5220    |          1870.2               |      1254.5          |        1961.4

Times are in milliseconds (ms).

[----------------------------------- jacrev : device cpu : chunks 100 -----------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -----------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 41        |            46.8               |        50.5          |          46.4
      (128, 128) : chunk_size 165     |           622.2               |       775.2          |         656.0
      (128, 144) : chunk_size 185     |           803.9               |       987.3          |         866.9
      (144, 144) : chunk_size 208     |          1021.1               |      1251.2          |        1088.2

Times are in milliseconds (ms).

[----------------------------------- jacrev : device cpu : chunks 200 -----------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -----------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 20        |            60.9               |        50.2          |          62.3
      (128, 128) : chunk_size 82      |           583.1               |       779.4          |         634.3
      (128, 144) : chunk_size 92      |           834.1               |      1005.8          |         472.3
      (144, 144) : chunk_size 104     |          1053.6               |      1277.0          |        1033.9

Times are in milliseconds (ms).

[----------------------------------- jacrev : device cpu : chunks 300 -----------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -----------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 13        |            77.7               |        50.4          |          79.6
      (128, 128) : chunk_size 55      |           578.9               |       782.3          |         626.9
      (128, 144) : chunk_size 61      |           718.2               |      1024.9          |         800.4
      (144, 144) : chunk_size 69      |           919.7               |      1313.7          |        1023.0

Times are in milliseconds (ms).
```

</details>

Benchmark CUDA: performs better with fewer chunks / bigger `chunk_size`.

<details>

```
[---------------------------------- jacrev : device cuda:1 : chunks 2 -----------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -----------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 2080      |          1485.7               |       923.8          |        1632.3
      (128, 128) : chunk_size 8256    |         25390.2               |     14103.2          |       33557.4
      (128, 144) : chunk_size 9288    |           801.7               |     16854.1          |       42894.6
      (144, 144) : chunk_size 10440   |          1003.5               |     21386.5          |       59648.5

Times are in microseconds (us).

3 / 3 : Shape (144, 144) : Device cuda:1 : chunks: 3
[---------------------------------- jacrev : device cuda:1 : chunks 3 -----------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -----------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1386      |          1474.5               |       924.5          |        1655.5
      (128, 128) : chunk_size 5504    |         25368.9               |     10156.0          |       34022.1
      (128, 144) : chunk_size 6192    |         25223.0               |     12933.7          |       56418.5
      (144, 144) : chunk_size 6960    |         24729.3               |     16367.4          |       68744.7

Times are in microseconds (us).

3 / 3 : Shape (144, 144) : Device cuda:1 : chunks: 4
[---------------------------------- jacrev : device cuda:1 : chunks 4 -----------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -----------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1040      |          1489.2               |       924.4          |        1679.6
      (128, 128) : chunk_size 4128    |         25370.4               |      8987.4          |       57201.3
      (128, 144) : chunk_size 4644    |         32239.1               |     10136.2          |       72406.5
      (144, 144) : chunk_size 5220    |         40994.3               |     12867.8          |      108653.4

Times are in microseconds (us).

3 / 3 : Shape (144, 144) : Device cuda:1 : chunks: 100
[--------------------------------- jacrev : device cuda:1 : chunks 100 ----------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -----------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 41        |         21121.8               |       924.2          |       22753.5
      (128, 128) : chunk_size 165     |         23679.7               |     14284.4          |       26758.2
      (128, 144) : chunk_size 185     |         30082.3               |     18063.3          |       33553.5
      (144, 144) : chunk_size 208     |         38175.6               |     22839.5          |       42030.0

Times are in microseconds (us).
```

</details>

Benchmark Script

<details>

```python
import sys

import torch
import functorch
from torch.utils.benchmark import Timer, Compare


def prod(l):
    prod = 1
    for el in l:
        prod *= el
    return prod


def fn(x, y):
    return x + y, x.sum(0)


shapes = ((64, 64), (128, 128), (128, 144), (144, 144))

for device in ('cpu', 'cuda:1'):
    if device == 'cuda:1':
        chunks = (2, 3, 4, 100,)
    else:
        chunks = (2, 3, 4, 100, 200, 300)

    for chunk in chunks:
        results = []
        for shape in shapes:
            x = torch.zeros(*shape, dtype=torch.float, device=device)
            y = x.sum()
            # Total number of basis vectors (rows of the Jacobian) divided by
            # the desired number of chunks.
            chunk_size = (prod(shape) + prod(shape[1:])) // chunk
            jacrev_fn_chunked = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size)
            jacrev_fn_chunked_pre = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size,
                                                     _preallocate_and_copy=True)
            jacrev_fn = functorch.jacrev(fn, (0, 1), chunk_size=None)

            tasks = [("jacrev_fn_chunked(x, y)", "with chunk_size and stacked"),
                     ("jacrev_fn(x, y)", "without chunk_size"),
                     ("jacrev_fn_chunked_pre(x, y)", "with chunk_size and pre-allocated"),]
            timers = [Timer(stmt=stmt,
                            label=f"jacrev : device {device} : chunks {chunk}",
                            sub_label=f"{(shape)} : chunk_size {chunk_size}",
                            description=desc,
                            globals=globals()) for stmt, desc in tasks]

            for i, timer in enumerate(timers):
                results.append(timer.blocked_autorange(min_run_time=2.))
                print(f"\r{i + 1} / {len(timers)} : Shape {shape} : Device {device} : chunks: {chunk}", end="")
                sys.stdout.flush()

        print()
        comparison = Compare(results)
        comparison.print()
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89376
Approved by: https://github.com/zou3519
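A minimal usage sketch of the new kwarg (assuming the `torch.func.jacrev` spelling under which this API was eventually exposed; the chunked result should match the unchunked one):

```python
import torch
from torch.func import jacrev
from torch.utils._pytree import tree_flatten

def fn(x, y):
    return x + y, x.sum(0)

x = torch.randn(8, 8)
y = x.sum()

# Unchunked: all basis vectors are vmapped over at once.
jac_full = jacrev(fn, argnums=(0, 1))(x, y)

# Chunked: at most 16 basis vectors are processed per vjp call,
# trading some speed for lower peak memory.
jac_chunked = jacrev(fn, argnums=(0, 1), chunk_size=16)(x, y)

for a, b in zip(tree_flatten(jac_full)[0], tree_flatten(jac_chunked)[0]):
    assert torch.allclose(a, b)
```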
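The two assembly strategies being compared can be sketched as follows (a simplified illustration, not the actual `jacrev` internals; the hypothetical `per_chunk` stands in for the per-chunk vjp computation over a slice of the basis):

```python
import torch

def per_chunk(basis_chunk):
    # Stand-in for the chunked vjp: maps a (chunk, N) slice of the basis
    # to the corresponding (chunk, N) slice of the Jacobian.
    return basis_chunk * 2.0

N, chunk_size = 10, 4
basis = torch.eye(N)

# Stacked approach: collect chunk results in a list, then concatenate.
# Peak memory briefly holds both the list of chunks and the stacked copy.
chunks = [per_chunk(basis[i:i + chunk_size]) for i in range(0, N, chunk_size)]
stacked = torch.cat(chunks, dim=0)

# Pre-allocation approach: allocate the full result once and copy each
# chunk into its slice, avoiding the extra stacked copy.
prealloc = basis.new_zeros(N, N)
for i in range(0, N, chunk_size):
    prealloc[i:i + chunk_size].copy_(per_chunk(basis[i:i + chunk_size]))

assert torch.equal(stacked, prealloc)
```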