jacrev : Support chunked computation (#89376)
Ref: https://github.com/pytorch/functorch/issues/680
We introduce a `chunk_size` kwarg in `jacrev` to control whether the Jacobian computation is chunked; if so, `chunk_size` dictates the maximum size of each chunk.
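A minimal usage sketch (the function and shapes mirror the benchmark script below; the chunk size of 4096 is just an illustrative value):
```python
import torch
import functorch

def fn(x, y):
    return x + y, x.sum(0)

x = torch.randn(128, 128)
y = x.sum()

# Unchunked (default): the whole Jacobian is computed in one go.
jac_full = functorch.jacrev(fn, (0, 1), chunk_size=None)(x, y)

# Chunked: at most 4096 rows of the Jacobian are computed per iteration,
# trading some speed (notably on CUDA) for lower peak memory.
jac_chunked = functorch.jacrev(fn, (0, 1), chunk_size=4096)(x, y)

# Both paths produce the same (nested tuple of) Jacobians.
torch.testing.assert_close(jac_chunked, jac_full)
```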
We try two approaches (sketched below):
* Stacked approach: append each chunk's intermediate result to a list and stack/concatenate the results at the end.
* Pre-allocation approach: pre-allocate a zeros tensor and copy each chunk's result into it.
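A rough, self-contained sketch of the two strategies (not the actual `jacrev` internals; `compute_chunk` is a hypothetical stand-in for the per-chunk vmap-of-vjp computation):
```python
import torch

def compute_chunk(start, length, in_numel):
    # Hypothetical stand-in: returns the Jacobian rows for output (basis)
    # indices [start, start + length) as a (length, in_numel) tensor.
    return torch.randn(length, in_numel)

out_numel, in_numel, chunk_size = 10_000, 512, 4096

# Stacked approach: collect each chunk's result in a list, then
# concatenate/stack along the chunked dimension at the end.
pieces = []
for start in range(0, out_numel, chunk_size):
    length = min(chunk_size, out_numel - start)
    pieces.append(compute_chunk(start, length, in_numel))
jac_stacked = torch.cat(pieces, dim=0)

# Pre-allocation approach: allocate the full result up front and copy
# each chunk into its slice, avoiding the temporary list of chunk tensors.
jac_prealloc = torch.zeros(out_numel, in_numel)
for start in range(0, out_numel, chunk_size):
    length = min(chunk_size, out_numel - start)
    jac_prealloc[start:start + length].copy_(compute_chunk(start, length, in_numel))
```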
For the memory benchmark, see https://github.com/pytorch/pytorch/pull/89376#issuecomment-1348479098
Benchmark CPU: chunked computation performs better with more chunks (smaller `chunk_size`).
NOTE: There seems to be a lot of noise for shape `(64, 64)`.
<details>
```
[----------------------------------------------- jacrev : device cpu : chunks 2 -----------------------------------------------]
| with chunk_size and stacked | without chunk_size | with chunk_size and pre-allocated
1 threads: ---------------------------------------------------------------------------------------------------------------------
(64, 64) : chunk_size 2080 | 76.2 | 50.9 | 80.1
(128, 128) : chunk_size 8256 | 1172.8 | 783.3 | 1225.5
(128, 144) : chunk_size 9288 | 1475.1 | 990.4 | 1548.3
(144, 144) : chunk_size 10440 | 1871.3 | 1254.4 | 1971.2
Times are in milliseconds (ms).
[----------------------------------------------- jacrev : device cpu : chunks 3 ----------------------------------------------]
| with chunk_size and stacked | without chunk_size | with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
(64, 64) : chunk_size 1386 | 39.9 | 25.8 | 58.8
(128, 128) : chunk_size 5504 | 1182.6 | 782.2 | 1229.7
(128, 144) : chunk_size 6192 | 1483.6 | 995.4 | 1550.6
(144, 144) : chunk_size 6960 | 1879.1 | 1257.7 | 1960.5
Times are in milliseconds (ms).
[----------------------------------------------- jacrev : device cpu : chunks 4 ----------------------------------------------]
| with chunk_size and stacked | without chunk_size | with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
(64, 64) : chunk_size 1040 | 41.7 | 50.6 | 29.1
(128, 128) : chunk_size 4128 | 1171.6 | 782.3 | 1226.7
(128, 144) : chunk_size 4644 | 1482.2 | 994.6 | 1550.9
(144, 144) : chunk_size 5220 | 1870.2 | 1254.5 | 1961.4
Times are in milliseconds (ms).
[--------------------------------------------- jacrev : device cpu : chunks 100 ---------------------------------------------]
| with chunk_size and stacked | without chunk_size | with chunk_size and pre-allocated
1 threads: -------------------------------------------------------------------------------------------------------------------
(64, 64) : chunk_size 41 | 46.8 | 50.5 | 46.4
(128, 128) : chunk_size 165 | 622.2 | 775.2 | 656.0
(128, 144) : chunk_size 185 | 803.9 | 987.3 | 866.9
(144, 144) : chunk_size 208 | 1021.1 | 1251.2 | 1088.2
Times are in milliseconds (ms).
[--------------------------------------------- jacrev : device cpu : chunks 200 ---------------------------------------------]
| with chunk_size and stacked | without chunk_size | with chunk_size and pre-allocated
1 threads: -------------------------------------------------------------------------------------------------------------------
(64, 64) : chunk_size 20 | 60.9 | 50.2 | 62.3
(128, 128) : chunk_size 82 | 583.1 | 779.4 | 634.3
(128, 144) : chunk_size 92 | 834.1 | 1005.8 | 472.3
(144, 144) : chunk_size 104 | 1053.6 | 1277.0 | 1033.9
Times are in milliseconds (ms).
[--------------------------------------------- jacrev : device cpu : chunks 300 --------------------------------------------]
| with chunk_size and stacked | without chunk_size | with chunk_size and pre-allocated
1 threads: ------------------------------------------------------------------------------------------------------------------
(64, 64) : chunk_size 13 | 77.7 | 50.4 | 79.6
(128, 128) : chunk_size 55 | 578.9 | 782.3 | 626.9
(128, 144) : chunk_size 61 | 718.2 | 1024.9 | 800.4
(144, 144) : chunk_size 69 | 919.7 | 1313.7 | 1023.0
Times are in milliseconds (ms).
```
</details>
Benchmark CUDA: chunked computation performs better with fewer chunks (larger `chunk_size`).
<details>
```
[--------------------------------------------- jacrev : device cuda:1 : chunks 2 ----------------------------------------------]
| with chunk_size and stacked | without chunk_size | with chunk_size and pre-allocated
1 threads: ---------------------------------------------------------------------------------------------------------------------
(64, 64) : chunk_size 2080 | 1485.7 | 923.8 | 1632.3
(128, 128) : chunk_size 8256 | 25390.2 | 14103.2 | 33557.4
(128, 144) : chunk_size 9288 | 801.7 | 16854.1 | 42894.6
(144, 144) : chunk_size 10440 | 1003.5 | 21386.5 | 59648.5
Times are in microseconds (us).
[--------------------------------------------- jacrev : device cuda:1 : chunks 3 ---------------------------------------------]
| with chunk_size and stacked | without chunk_size | with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
(64, 64) : chunk_size 1386 | 1474.5 | 924.5 | 1655.5
(128, 128) : chunk_size 5504 | 25368.9 | 10156.0 | 34022.1
(128, 144) : chunk_size 6192 | 25223.0 | 12933.7 | 56418.5
(144, 144) : chunk_size 6960 | 24729.3 | 16367.4 | 68744.7
Times are in microseconds (us).
[--------------------------------------------- jacrev : device cuda:1 : chunks 4 ---------------------------------------------]
| with chunk_size and stacked | without chunk_size | with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
(64, 64) : chunk_size 1040 | 1489.2 | 924.4 | 1679.6
(128, 128) : chunk_size 4128 | 25370.4 | 8987.4 | 57201.3
(128, 144) : chunk_size 4644 | 32239.1 | 10136.2 | 72406.5
(144, 144) : chunk_size 5220 | 40994.3 | 12867.8 | 108653.4
Times are in microseconds (us).
[------------------------------------------- jacrev : device cuda:1 : chunks 100 --------------------------------------------]
| with chunk_size and stacked | without chunk_size | with chunk_size and pre-allocated
1 threads: -------------------------------------------------------------------------------------------------------------------
(64, 64) : chunk_size 41 | 21121.8 | 924.2 | 22753.5
(128, 128) : chunk_size 165 | 23679.7 | 14284.4 | 26758.2
(128, 144) : chunk_size 185 | 30082.3 | 18063.3 | 33553.5
(144, 144) : chunk_size 208 | 38175.6 | 22839.5 | 42030.0
Times are in microseconds (us).
```
</details>
Benchmark Script
<details>
```python
import functorch
import torch
import itertools
import time
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare
import sys
import pickle
from torch import profiler
import math


def prod(l):
    prod = 1
    for el in l:
        prod *= el
    return prod


def fn(x, y):
    return x + y, x.sum(0)


shapes = ((64, 64), (128, 128), (128, 144), (144, 144))

for device in ('cpu', 'cuda:1'):
    if device == 'cuda:1':
        chunks = (2, 3, 4, 100,)
    else:
        chunks = (2, 3, 4, 100, 200, 300)

    for chunk in chunks:
        results = []
        for shape in shapes:
            x = torch.zeros(*shape, dtype=torch.float, device=device)
            y = x.sum()
            chunk_size = (prod(shape) + prod(shape[1:])) // chunk
            jacrev_fn_chunked = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size)
            jacrev_fn_chunked_pre = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size, _preallocate_and_copy=True)
            jacrev_fn = functorch.jacrev(fn, (0, 1), chunk_size=None)

            tasks = [("jacrev_fn_chunked(x, y)", "with chunk_size and stacked"),
                     ("jacrev_fn(x, y)", "without chunk_size"),
                     ("jacrev_fn_chunked_pre(x, y)", "with chunk_size and pre-allocated"),]
            timers = [Timer(stmt=stmt,
                            label=f"jacrev : device {device} : chunks {chunk}",
                            sub_label=f"{(shape)} : chunk_size {chunk_size}",
                            description=desc,
                            globals=globals()) for stmt, desc in tasks]

            for i, timer in enumerate(timers):
                results.append(
                    timer.blocked_autorange(min_run_time=2.)
                )
                print(f"\r{i + 1} / {len(timers)} : Shape {shape} : Device {device} : chunks: {chunk}", end="")
                sys.stdout.flush()

        print()
        comparison = Compare(results)
        comparison.print()
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89376
Approved by: https://github.com/zou3519