pytorch
8a21cac3 - Improve interpolate() speed for channels_last CPU videos (#90302)

Commit
2 years ago
Improve interpolate() speed for channels_last CPU videos (#90302) This is the exact same PR as https://github.com/pytorch/pytorch/pull/86361, but on Videos (3D) instead of images (2D). For torchvision training use-cases (num_threads=1), the speed-ups range in 1X-2X. When num_threads>1 the speed-ups are a lot higher, up to ~30X Benchmarks details: <details > ``` main branch=c6942dbbfbf836450898aa9a0c08aefe437d0765 input shape output size mode dtype num_threads speed-up main PR (1, 3, 8, 256, 256) -> (16, 320, 320) linear float32 num_threads=1 1.0X 54.7ms vs 55.7ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest float32 num_threads=1 1.7X 40.5ms vs 24.4ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest uint8 num_threads=1 1.4X 33.1ms vs 23.7ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact float32 num_threads=1 2.0X 47.5ms vs 24.3ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact uint8 num_threads=1 1.7X 39.9ms vs 23.7ms (1, 3, 8, 256, 256) -> (16, 320, 320) linear float32 num_threads=2 2.2X 54.6ms vs 25.1ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest float32 num_threads=2 2.3X 21.2ms vs 9.3ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest uint8 num_threads=2 1.4X 16.5ms vs 12.0ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact float32 num_threads=2 2.6X 24.3ms vs 9.3ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact uint8 num_threads=2 1.7X 19.9ms vs 12.0ms (1, 3, 8, 256, 256) -> (16, 320, 320) linear float32 num_threads=12 10X 54.3ms vs 5.4ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest float32 num_threads=12 2.5X 4.1ms vs 1.6ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest uint8 num_threads=12 1.4X 2.9ms vs 2.1ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact float32 num_threads=12 1.7X 4.8ms vs 2.8ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact uint8 num_threads=12 1.7X 3.5ms vs 2.1ms (1, 3, 8, 256, 256) -> (16, 320, 320) linear float32 num_threads=32 20X 54.2ms vs 2.7ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest float32 num_threads=32 1.5X 2.2ms vs 1.5ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest uint8 num_threads=32 1.6X 1.3ms vs 0.8ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact float32 num_threads=32 1.3X 1.8ms vs 1.4ms (1, 3, 8, 256, 256) -> (16, 320, 320) nearest-exact uint8 num_threads=32 1.7X 1.3ms vs 0.8ms (1, 3, 16, 320, 320) -> (8, 256, 256) linear float32 num_threads=1 1.0X 15.4ms vs 16.0ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest float32 num_threads=1 2.0X 12.3ms vs 6.0ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest uint8 num_threads=1 1.6X 12.0ms vs 7.7ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact float32 num_threads=1 2.2X 13.1ms vs 6.0ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact uint8 num_threads=1 1.7X 12.8ms vs 7.6ms (1, 3, 16, 320, 320) -> (8, 256, 256) linear float32 num_threads=2 1.9X 15.5ms vs 8.2ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest float32 num_threads=2 2.0X 6.1ms vs 3.1ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest uint8 num_threads=2 1.5X 6.0ms vs 3.9ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact float32 num_threads=2 2.2X 6.6ms vs 3.0ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact uint8 num_threads=2 1.7X 6.5ms vs 3.9ms (1, 3, 16, 320, 320) -> (8, 256, 256) linear float32 num_threads=12 11X 15.5ms vs 1.4ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest float32 num_threads=12 2.0X 1.1ms vs 0.5ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest uint8 num_threads=12 1.6X 1.1ms vs 0.7ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact float32 num_threads=12 2.1X 1.2ms vs 0.5ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact uint8 num_threads=12 1.5X 1.1ms vs 0.8ms (1, 3, 16, 320, 320) -> (8, 256, 256) linear float32 num_threads=32 15X 15.4ms vs 1.0ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest float32 num_threads=32 1.7X 0.7ms vs 0.4ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest uint8 num_threads=32 1.3X 0.7ms vs 0.5ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact float32 num_threads=32 3X 0.7ms vs 0.2ms (1, 3, 16, 320, 320) -> (8, 256, 256) nearest-exact uint8 num_threads=32 2.6X 0.7ms vs 0.3ms (1, 3, 16, 320, 320) -> (32, 512, 512) linear float32 num_threads=1 1.0X 295.6ms vs 304.3ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest float32 num_threads=1 1.5X 223.2ms vs 144.3ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest uint8 num_threads=1 1.5X 177.7ms vs 121.0ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact float32 num_threads=1 1.8X 258.6ms vs 145.3ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact uint8 num_threads=1 1.6X 203.9ms vs 128.6ms (1, 3, 16, 320, 320) -> (32, 512, 512) linear float32 num_threads=2 1.8X 295.4ms vs 160.4ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest float32 num_threads=2 1.5X 119.0ms vs 80.2ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest uint8 num_threads=2 1.4X 84.8ms vs 60.6ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact float32 num_threads=2 1.7X 136.1ms vs 80.1ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact uint8 num_threads=2 1.7X 102.2ms vs 60.5ms (1, 3, 16, 320, 320) -> (32, 512, 512) linear float32 num_threads=12 9X 295.3ms vs 32.3ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest float32 num_threads=12 1.4X 25.2ms vs 18.7ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest uint8 num_threads=12 1.4X 16.5ms vs 11.9ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact float32 num_threads=12 1.5X 28.1ms vs 18.8ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact uint8 num_threads=12 1.7X 19.4ms vs 11.5ms (1, 3, 16, 320, 320) -> (32, 512, 512) linear float32 num_threads=32 18X 294.7ms vs 16.2ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest float32 num_threads=32 1.2X 14.4ms vs 12.5ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest uint8 num_threads=32 1.2X 5.9ms vs 4.8ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact float32 num_threads=32 1.2X 14.5ms vs 12.5ms (1, 3, 16, 320, 320) -> (32, 512, 512) nearest-exact uint8 num_threads=32 1.4X 6.9ms vs 4.8ms (1, 3, 32, 512, 512) -> (16, 320, 320) linear float32 num_threads=1 0.9X 48.6ms vs 55.1ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest float32 num_threads=1 2.0X 38.8ms vs 19.2ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest uint8 num_threads=1 1.6X 37.6ms vs 23.8ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact float32 num_threads=1 2.1X 41.2ms vs 19.2ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact uint8 num_threads=1 1.7X 39.9ms vs 23.8ms (1, 3, 32, 512, 512) -> (16, 320, 320) linear float32 num_threads=2 1.9X 48.8ms vs 25.3ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest float32 num_threads=2 2.0X 19.2ms vs 9.5ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest uint8 num_threads=2 1.6X 18.8ms vs 12.0ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact float32 num_threads=2 2.2X 20.5ms vs 9.5ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact uint8 num_threads=2 1.7X 20.0ms vs 12.0ms (1, 3, 32, 512, 512) -> (16, 320, 320) linear float32 num_threads=12 11X 48.6ms vs 4.6ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest float32 num_threads=12 2.0X 3.4ms vs 1.7ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest uint8 num_threads=12 1.6X 3.3ms vs 2.1ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact float32 num_threads=12 2.1X 3.6ms vs 1.7ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact uint8 num_threads=12 1.7X 3.5ms vs 2.1ms (1, 3, 32, 512, 512) -> (16, 320, 320) linear float32 num_threads=32 27X 48.3ms vs 1.8ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest float32 num_threads=32 1.1X 2.2ms vs 2.0ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest uint8 num_threads=32 2.6X 2.1ms vs 0.8ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact float32 num_threads=32 2.4X 2.3ms vs 0.9ms (1, 3, 32, 512, 512) -> (16, 320, 320) nearest-exact uint8 num_threads=32 2.6X 2.2ms vs 0.8ms ``` </details> Code: <details> ```py import operator_benchmark as op_bench import torch """Microbenchmarks for interpolate operator.""" class InterpolateBenchmark(op_bench.TorchBenchmarkBase): def init(self, input_size, output_size, channels_last=False, mode='linear', dtype=torch.float): input_image = torch.randint(0, 256, size=input_size, dtype=dtype, device='cpu', requires_grad=self.auto_set()) if channels_last: if input_image.ndim == 4: input_image = input_image.contiguous(memory_format=torch.channels_last) elif input_image.ndim == 5: input_image = input_image.contiguous(memory_format=torch.channels_last_3d) else: raise ValueError( f"Can not set channels_last to the input of {input_image.ndim} dims" ) align_corners = None if "nearest" in mode else False if mode == "linear": mode = { 3: 'linear', 4: 'bilinear', 5: 'trilinear', }[input_image.ndim] self.inputs = { "input_image": input_image, "output_size": output_size, "mode": mode, "align_corners": align_corners, } self.set_module_name("interpolate") def forward(self, input_image, output_size, mode, align_corners): return torch.nn.functional.interpolate(input_image, size=output_size, mode=mode, align_corners=align_corners) def make_config(): sizes = ( ((16, 320, 320), (8, 256, 256)), ((16, 320, 320), (32, 512, 512)), ) attrs = [] for (DHW1, DHW2) in sizes: attrs.append([(1, 3, *DHW1), DHW2]) attrs.append([(1, 3, *DHW2), DHW1]) config = op_bench.config_list( attr_names=["input_size", "output_size"], attrs=attrs, cross_product_configs={ 'channels_last': [True], 'mode': ["linear", "nearest", "nearest-exact"], 'dtype': [torch.float, torch.uint8] }, tags=["short"], ) # Need to remove instances with both torch.int and linear # Note: this is naaaasty def get_mode(l): for d in l: if "mode" in d: return d["mode"] def get_dtype(l): for d in l: if "dtype" in d: return d["dtype"] config = [l for l in config if not(get_mode(l) == "linear" and get_dtype(l) == torch.uint8)] return config config = make_config() op_bench.generate_pt_test(config, InterpolateBenchmark) if __name__ == "__main__": op_bench.benchmark_runner.main() ``` ```py import re import argparse parser = argparse.ArgumentParser() parser.add_argument("f3", nargs="?", default="main") parser.add_argument("f2", nargs="?", default="new") args = parser.parse_args() with open(args.f1) as f: main = f.readlines() with open(args.f2) as f: new = f.readlines() out = [] for main_line, new_line in zip(main, new): # num_threads=1 # TODO: remove if main_line.startswith("num_threads="): num_threads = int(main_line.split("=")[-1]) if main_line.startswith("# Input"): deets = f"{main_line.strip()}, {num_threads=}" if main_line.startswith("Forward"): main_time = float(main_line.split()[-1]) new_time = float(new_line.split()[-1]) ratio = main_time / new_time fmt = ".1f" if ratio < 3 else ".0f" improv = f"{ratio:{fmt}}X" time_fmt = ",.3f" if new_time < 100 else ",.1f" deets = deets.strip().replace("# Input: ", "") deets = deets.replace(": ", "=") deets = deets.replace("input_size=", "") deets = deets.replace(", output_size=", " -> ") deets = deets.replace("dtype=torch.", "") deets = deets.replace("mode=", "") deets = deets.replace("channels_last=True, ", "") split = deets.split(",") size = ','.join(split[:-3]) mode, dtype, threads = split[-3:] deets = f"{size:<30} {mode:<15} {dtype:<10} {threads:<15}" l = f"{deets} {improv:<5} {main_time / 1000:{time_fmt}}ms vs {new_time / 1000:{time_fmt}}ms" out.append(l) def key(s): # s = ''.join(s.split()[1:]) # remove "N.nX" part num_threads = (int(re.findall(r"num_threads=(\d+)", s)[0]),) input_shape, output_shape = re.findall("\(.*?\)", s) input_shape = input_shape[1:-1] # remove parenthesis input_HW = tuple(int(x) for x in input_shape.split(",")[-2:]) input_C = (-int(input_shape.split(",")[1]),) output_HW = tuple(int(x) for x in output_shape[1:-1].split(",")) is_downsample = (output_HW[0] < input_HW[0],) if "linear" in s: mode = "linear" elif "nearest-exact" in s: mode = "nearest-exact" else: assert "nearest" in s mode = "nearest" mode = (mode,) return is_downsample + input_HW + output_HW + num_threads + input_C + mode for i, l in enumerate(sorted(out, key=key)): if i % 5 == 0: print() # if i % 10 == 0 and i % 40 != 0: # print() # if i % 40 == 0: # print("-" * 100) print(l) ``` </details > Pull Request resolved: https://github.com/pytorch/pytorch/pull/90302 Approved by: https://github.com/vfdev-5, https://github.com/fmassa
Author
Committer
Parents
Loading