Improve interpolate() speed for channels_last CPU videos (#90302)
This is the exact same PR as https://github.com/pytorch/pytorch/pull/86361, but on Videos (3D) instead of images (2D).
For torchvision training use-cases (num_threads=1), the speed-ups range from 1X to 2X. When num_threads>1 the speed-ups are a lot higher, up to ~30X.
Benchmark details:
<details >
```
main branch=c6942dbbfbf836450898aa9a0c08aefe437d0765
input shape output size mode dtype num_threads speed-up main PR
(1, 3, 8, 256, 256) -> (16, 320, 320) linear float32 num_threads=1 1.0X 54.7ms vs 55.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest float32 num_threads=1 1.7X 40.5ms vs 24.4ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest uint8 num_threads=1 1.4X 33.1ms vs 23.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact float32 num_threads=1 2.0X 47.5ms vs 24.3ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact uint8 num_threads=1 1.7X 39.9ms vs 23.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320) linear float32 num_threads=2 2.2X 54.6ms vs 25.1ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest float32 num_threads=2 2.3X 21.2ms vs 9.3ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest uint8 num_threads=2 1.4X 16.5ms vs 12.0ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact float32 num_threads=2 2.6X 24.3ms vs 9.3ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact uint8 num_threads=2 1.7X 19.9ms vs 12.0ms
(1, 3, 8, 256, 256) -> (16, 320, 320) linear float32 num_threads=12 10X 54.3ms vs 5.4ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest float32 num_threads=12 2.5X 4.1ms vs 1.6ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest uint8 num_threads=12 1.4X 2.9ms vs 2.1ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact float32 num_threads=12 1.7X 4.8ms vs 2.8ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact uint8 num_threads=12 1.7X 3.5ms vs 2.1ms
(1, 3, 8, 256, 256) -> (16, 320, 320) linear float32 num_threads=32 20X 54.2ms vs 2.7ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest float32 num_threads=32 1.5X 2.2ms vs 1.5ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest uint8 num_threads=32 1.6X 1.3ms vs 0.8ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact float32 num_threads=32 1.3X 1.8ms vs 1.4ms
(1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact uint8 num_threads=32 1.7X 1.3ms vs 0.8ms
(1, 3, 16, 320, 320) -> (8, 256, 256) linear float32 num_threads=1 1.0X 15.4ms vs 16.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest float32 num_threads=1 2.0X 12.3ms vs 6.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest uint8 num_threads=1 1.6X 12.0ms vs 7.7ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact float32 num_threads=1 2.2X 13.1ms vs 6.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact uint8 num_threads=1 1.7X 12.8ms vs 7.6ms
(1, 3, 16, 320, 320) -> (8, 256, 256) linear float32 num_threads=2 1.9X 15.5ms vs 8.2ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest float32 num_threads=2 2.0X 6.1ms vs 3.1ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest uint8 num_threads=2 1.5X 6.0ms vs 3.9ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact float32 num_threads=2 2.2X 6.6ms vs 3.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact uint8 num_threads=2 1.7X 6.5ms vs 3.9ms
(1, 3, 16, 320, 320) -> (8, 256, 256) linear float32 num_threads=12 11X 15.5ms vs 1.4ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest float32 num_threads=12 2.0X 1.1ms vs 0.5ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest uint8 num_threads=12 1.6X 1.1ms vs 0.7ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact float32 num_threads=12 2.1X 1.2ms vs 0.5ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact uint8 num_threads=12 1.5X 1.1ms vs 0.8ms
(1, 3, 16, 320, 320) -> (8, 256, 256) linear float32 num_threads=32 15X 15.4ms vs 1.0ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest float32 num_threads=32 1.7X 0.7ms vs 0.4ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest uint8 num_threads=32 1.3X 0.7ms vs 0.5ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact float32 num_threads=32 3X 0.7ms vs 0.2ms
(1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact uint8 num_threads=32 2.6X 0.7ms vs 0.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512) linear float32 num_threads=1 1.0X 295.6ms vs 304.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest float32 num_threads=1 1.5X 223.2ms vs 144.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest uint8 num_threads=1 1.5X 177.7ms vs 121.0ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact float32 num_threads=1 1.8X 258.6ms vs 145.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact uint8 num_threads=1 1.6X 203.9ms vs 128.6ms
(1, 3, 16, 320, 320) -> (32, 512, 512) linear float32 num_threads=2 1.8X 295.4ms vs 160.4ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest float32 num_threads=2 1.5X 119.0ms vs 80.2ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest uint8 num_threads=2 1.4X 84.8ms vs 60.6ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact float32 num_threads=2 1.7X 136.1ms vs 80.1ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact uint8 num_threads=2 1.7X 102.2ms vs 60.5ms
(1, 3, 16, 320, 320) -> (32, 512, 512) linear float32 num_threads=12 9X 295.3ms vs 32.3ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest float32 num_threads=12 1.4X 25.2ms vs 18.7ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest uint8 num_threads=12 1.4X 16.5ms vs 11.9ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact float32 num_threads=12 1.5X 28.1ms vs 18.8ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact uint8 num_threads=12 1.7X 19.4ms vs 11.5ms
(1, 3, 16, 320, 320) -> (32, 512, 512) linear float32 num_threads=32 18X 294.7ms vs 16.2ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest float32 num_threads=32 1.2X 14.4ms vs 12.5ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest uint8 num_threads=32 1.2X 5.9ms vs 4.8ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact float32 num_threads=32 1.2X 14.5ms vs 12.5ms
(1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact uint8 num_threads=32 1.4X 6.9ms vs 4.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320) linear float32 num_threads=1 0.9X 48.6ms vs 55.1ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest float32 num_threads=1 2.0X 38.8ms vs 19.2ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest uint8 num_threads=1 1.6X 37.6ms vs 23.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact float32 num_threads=1 2.1X 41.2ms vs 19.2ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact uint8 num_threads=1 1.7X 39.9ms vs 23.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320) linear float32 num_threads=2 1.9X 48.8ms vs 25.3ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest float32 num_threads=2 2.0X 19.2ms vs 9.5ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest uint8 num_threads=2 1.6X 18.8ms vs 12.0ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact float32 num_threads=2 2.2X 20.5ms vs 9.5ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact uint8 num_threads=2 1.7X 20.0ms vs 12.0ms
(1, 3, 32, 512, 512) -> (16, 320, 320) linear float32 num_threads=12 11X 48.6ms vs 4.6ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest float32 num_threads=12 2.0X 3.4ms vs 1.7ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest uint8 num_threads=12 1.6X 3.3ms vs 2.1ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact float32 num_threads=12 2.1X 3.6ms vs 1.7ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact uint8 num_threads=12 1.7X 3.5ms vs 2.1ms
(1, 3, 32, 512, 512) -> (16, 320, 320) linear float32 num_threads=32 27X 48.3ms vs 1.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest float32 num_threads=32 1.1X 2.2ms vs 2.0ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest uint8 num_threads=32 2.6X 2.1ms vs 0.8ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact float32 num_threads=32 2.4X 2.3ms vs 0.9ms
(1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact uint8 num_threads=32 2.6X 2.2ms vs 0.8ms
```
</details>
Code:
<details>
```py
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for interpolate operator."""
class InterpolateBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case wrapping ``torch.nn.functional.interpolate`` on CPU."""

    def init(self, input_size, output_size, channels_last=False, mode='linear', dtype=torch.float):
        """Create the input tensor and record the arguments for ``forward``.

        Args:
            input_size: shape of the random input tensor.
            output_size: target size passed to ``interpolate`` as ``size``.
            channels_last: when true, convert a 4-D input to
                ``torch.channels_last`` or a 5-D input to
                ``torch.channels_last_3d``.
            mode: interpolation mode; ``"linear"`` is translated to the
                variant matching the input's number of dimensions.
            dtype: dtype of the generated input tensor.

        Raises:
            ValueError: if ``channels_last`` is requested for an input that
                is neither 4-D nor 5-D.
        """
        # Integer values drawn from [0, 256), stored with the requested dtype.
        input_image = torch.randint(
            0,
            256,
            size=input_size,
            dtype=dtype,
            device='cpu',
            requires_grad=self.auto_set(),
        )

        if channels_last:
            # Only 4-D and 5-D tensors have a channels-last memory format here.
            formats_by_ndim = {
                4: torch.channels_last,
                5: torch.channels_last_3d,
            }
            if input_image.ndim not in formats_by_ndim:
                raise ValueError(
                    f"Can not set channels_last to the input of {input_image.ndim} dims"
                )
            input_image = input_image.contiguous(
                memory_format=formats_by_ndim[input_image.ndim]
            )

        # The "nearest" family of modes is given align_corners=None; every
        # other mode is given an explicit False.
        if "nearest" in mode:
            align_corners = None
        else:
            align_corners = False

        # Resolve the generic "linear" into the dimension-specific mode name.
        if mode == "linear":
            linear_mode_by_ndim = {
                3: 'linear',
                4: 'bilinear',
                5: 'trilinear',
            }
            mode = linear_mode_by_ndim[input_image.ndim]

        self.inputs = {
            "input_image": input_image,
            "output_size": output_size,
            "mode": mode,
            "align_corners": align_corners,
        }
        self.set_module_name("interpolate")

    def forward(self, input_image, output_size, mode, align_corners):
        """Run ``interpolate`` on ``input_image`` with the recorded arguments."""
        return torch.nn.functional.interpolate(
            input_image,
            size=output_size,
            mode=mode,
            align_corners=align_corners,
        )
def make_config():
    """Return the benchmark configurations for 5-D inputs.

    Every size pair is used in both directions (first -> second and
    second -> first) with a leading ``(1, 3)`` batch/channel prefix, then
    crossed with ``channels_last=True``, three interpolation modes and two
    dtypes. Configurations combining ``"linear"`` with ``torch.uint8`` are
    removed from the result.
    """
    size_pairs = (
        ((16, 320, 320), (8, 256, 256)),
        ((16, 320, 320), (32, 512, 512)),
    )

    attrs = []
    for first_dhw, second_dhw in size_pairs:
        attrs.extend((
            [(1, 3, *first_dhw), second_dhw],
            [(1, 3, *second_dhw), first_dhw],
        ))

    all_configs = op_bench.config_list(
        attr_names=["input_size", "output_size"],
        attrs=attrs,
        cross_product_configs={
            'channels_last': [True],
            'mode': ["linear", "nearest", "nearest-exact"],
            'dtype': [torch.float, torch.uint8]
        },
        tags=["short"],
    )

    def lookup(entry, field):
        # Scan the items of one configuration entry for the one holding
        # `field`; yields None when no item contains it.
        for item in entry:
            if field in item:
                return item[field]
        return None

    def is_linear_uint8(entry):
        return lookup(entry, "mode") == "linear" and lookup(entry, "dtype") == torch.uint8

    # Drop the (linear, uint8) combinations produced by the cross product.
    return [entry for entry in all_configs if not is_linear_uint8(entry)]
# Build the configurations at import time and register one benchmark test
# per configuration for InterpolateBenchmark.
config = make_config()
op_bench.generate_pt_test(config, InterpolateBenchmark)

if __name__ == "__main__":
    # Executed as a script: hand control to the operator_benchmark runner.
    op_bench.benchmark_runner.main()
```
```py
import re
import argparse
# Compare two benchmark output files line by line and build one formatted
# summary row per "Forward" timing line.
#
# Usage: two optional positional arguments, the baseline file (default
# "main") and the candidate file (default "new").
parser = argparse.ArgumentParser()
# The positional name is also the attribute name on the parsed namespace, so
# it must be "f1" to match the `args.f1` lookup below.
parser.add_argument("f1", nargs="?", default="main")
parser.add_argument("f2", nargs="?", default="new")
args = parser.parse_args()

with open(args.f1) as f:
    main = f.readlines()
with open(args.f2) as f:
    new = f.readlines()

out = []
# The two files are walked in lockstep: line i of the baseline is paired with
# line i of the candidate (zip stops at the shorter file).
for main_line, new_line in zip(main, new):
    if main_line.startswith("num_threads="):
        # Remember the thread count announced for the following entries.
        num_threads = int(main_line.split("=")[-1])
    if main_line.startswith("# Input"):
        # Description of the benchmark case, tagged with the current thread count.
        deets = f"{main_line.strip()}, {num_threads=}"
    if main_line.startswith("Forward"):
        # The timing is the last whitespace-separated token on the line.
        main_time = float(main_line.split()[-1])
        new_time = float(new_line.split()[-1])
        ratio = main_time / new_time
        # Small ratios keep one decimal; larger ones are rounded to an integer.
        fmt = ".1f" if ratio < 3 else ".0f"
        improv = f"{ratio:{fmt}}X"
        time_fmt = ",.3f" if new_time < 100 else ",.1f"

        # Strip the verbose "key: value" description down to compact columns.
        deets = deets.strip().replace("# Input: ", "")
        deets = deets.replace(": ", "=")
        deets = deets.replace("input_size=", "")
        deets = deets.replace(", output_size=", " -> ")
        deets = deets.replace("dtype=torch.", "")
        deets = deets.replace("mode=", "")
        deets = deets.replace("channels_last=True, ", "")

        # The last three comma-separated fields are mode, dtype and thread
        # count; everything before them is the "input -> output" size text.
        split = deets.split(",")
        size = ','.join(split[:-3])
        mode, dtype, threads = split[-3:]
        deets = f"{size:<30} {mode:<15} {dtype:<10} {threads:<15}"

        # Times are divided by 1000 before being printed with an "ms" suffix
        # (presumably the input files report microseconds; verify against
        # the benchmark runner's output format).
        l = f"{deets} {improv:<5} {main_time / 1000:{time_fmt}}ms vs {new_time / 1000:{time_fmt}}ms"
        out.append(l)
def key(s):
    """Return the sort key for one formatted result row.

    The row must contain ``num_threads=<int>`` and exactly two parenthesised,
    comma-separated integer shapes (input first, output second), plus one of
    the mode names ``linear``, ``nearest-exact`` or ``nearest``.

    Returns:
        A flat tuple: ``(is_downsample, *input_HW, *output_dims,
        num_threads, -input_C, mode)``.

    Raises:
        AssertionError: if none of the known mode names occurs in ``s``.
    """
    num_threads = (int(re.findall(r"num_threads=(\d+)", s)[0]),)

    # Raw string: "\(" is not a valid string escape, so a non-raw literal
    # triggers an invalid-escape-sequence warning on recent Python versions.
    input_shape, output_shape = re.findall(r"\(.*?\)", s)
    input_shape = input_shape[1:-1]  # remove parenthesis
    input_dims = input_shape.split(",")

    # Last two input dimensions, and the negated second dimension so that
    # larger values of it sort first.
    input_HW = tuple(int(x) for x in input_dims[-2:])
    input_C = (-int(input_dims[1]),)
    output_HW = tuple(int(x) for x in output_shape[1:-1].split(","))

    # NOTE(review): this compares the FIRST output dimension with the
    # second-to-last input dimension. For 5-D shapes that is output depth vs
    # input height - confirm this is the intended downsample test.
    is_downsample = (output_HW[0] < input_HW[0],)

    # "linear" must be tested before the nearest variants, and
    # "nearest-exact" before "nearest", since the latter is a substring.
    if "linear" in s:
        mode = "linear"
    elif "nearest-exact" in s:
        mode = "nearest-exact"
    else:
        assert "nearest" in s
        mode = "nearest"
    mode = (mode,)

    return is_downsample + input_HW + output_HW + num_threads + input_C + mode
# Print the rows in sorted order, with an empty line before every group of
# five rows (including the very first group).
for position, row in enumerate(sorted(out, key=key)):
    starts_new_group = position % 5 == 0
    if starts_new_group:
        print()
    print(row)
```
</details >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90302
Approved by: https://github.com/vfdev-5, https://github.com/fmassa