pytorch
745dc3a1 - [inductor] optimize lowering for empty-related operators (#91350)

Commit
1 year ago
[inductor] optimize lowering for empty-related operators (#91350) For micro-benchmark, `new_empty_strided` and `new_empty` have poor performance with inductor compared to eager. The main reason is that inductor initializes new tensor with 0 during lowering, which generates a useless cpp kernel. Actually, it is not needed for operator semantics, but costs additional time. The same problem is also found in lowerings of `empty_strided` and `empty`. This PR tends to remove useless cpp kernel of tensor initialization by generating a NopKernelSchedulerNode instead of a SchedulerNode. The lowering functions of following operators are optimized: - `torch.empty` - `aten.empty` - `aten.new_empty` - `aten.empty_strided` - `aten.new_empty_strided` We take output code of `new_empty_strided` as example. _Before change_ ``` kernel_cpp_0 = async_compile.cpp(''' #include "/tmp/torchinductor_root/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h" extern "C" void kernel(float* __restrict__ out_ptr0) { #pragma omp parallel num_threads(28) { #pragma omp for for(long i0=0; i0<57600; i0+=1) { auto tmp0 = at::vec::Vectorized<float>(static_cast<float>(0)); tmp0.store(out_ptr0 + 16*i0); } #pragma omp for simd simdlen(8) for(long i0=921600; i0<921600; i0+=1) { auto tmp0 = static_cast<float>(0); out_ptr0[i0] = tmp0; } } } ''') async_compile.wait(globals()) del async_compile def call(args): arg0_1, = args args.clear() buf0 = empty_strided((60, 60, 256), (15360, 256, 1), device='cpu', dtype=torch.float32) kernel_cpp_0(c_void_p(buf0.data_ptr())) return (buf0, ) if __name__ == "__main__": from torch._dynamo.testing import rand_strided from torch._inductor.utils import print_performance arg0_1 = rand_strided((60, 60, 256), (60, 1, 3600), device='cpu', dtype=torch.float32) print_performance(lambda: call([arg0_1])) ``` _After change_ ``` async_compile.wait(globals()) del async_compile def call(args): arg0_1, = args args.clear() buf0 = empty_strided((60, 60, 256), (15360, 256, 1), device='cpu', dtype=torch.float32) return (buf0, ) if __name__ == "__main__": from torch._dynamo.testing import rand_strided from torch._inductor.utils import print_performance arg0_1 = rand_strided((60, 60, 256), (60, 1, 3600), device='cpu', dtype=torch.float32) print_performance(lambda: call([arg0_1])) ``` Performance data for eager v.s. inductor: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/xuanliao/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/xuanliao/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> </head> <body link="#0563C1" vlink="#954F72"> suite | op_name | improved_ratio_speedup0.2 | improved_ratio_speedup0.5 | improved_ratio_speedup0.8 | speedup_old_0.2 | RSD(3) | speedup_old_0.5 | RSD(3) | speedup_old_0.8 | RSD(3) | speedup_new_0.2 | RSD(3) | speedup_new_0.5 | RSD(3) | speedup_new_0.8 | RSD(3) -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- torchbench | aten.new_empty_strided.default | 235.94% | 100.94% | 50.23% | 0.325947 | 2.96% | 0.550267 | 2.03% | 0.747997 | 2.93% | 1.094985 | 0.81% | 1.105722 | 0.55% | 1.12372 | 0.68% huggingface | aten.new_empty_strided.default | 120.58% | 81.16% | 87.41% | 0.503116 | 28.27% | 0.668831 | 5.85% | 0.705637 | 2.76% | 1.109785 | 1.70% | 1.211641 | 0.74% | 1.322434 | 0.82% timm | aten.new_empty_strided.default | 129.24% | 72.75% | 47.91% | 0.490658 | 15.87% | 0.76711 | 13.11% | 0.904033 | 4.44% | 1.124806 | 1.19% | 1.325182 | 0.65% | 1.337114 | 1.01% torchbench | aten.new_empty.default | 69.41% | 1.60% | 0.90% | 0.732117 | 5.24% | 1.228356 | 1.18% | 1.241341 | 0.81% | 1.24031 | 1.96% | 1.248061 | 1.70% | 1.252525 | 1.84% huggingface | aten.new_empty.default | 150.01% | 79.29% | 39.91% | 0.49547 | 12.67% | 0.692498 | 22.11% | 0.889526 | 27.37% | 1.238706 | 1.58% | 1.241606 | 1.49% | 1.244506 | 1.41% timm | aten.new_empty.default | 11.61% | 11.13% | 11.07% | 1.115127 | 0.65% | 1.124302 | 0.80% | 1.132986 | 1.38% | 1.244582 | 1.12% | 1.249459 | 1.31% | 1.258416 | 1.14% </body> </html> Pull Request resolved: https://github.com/pytorch/pytorch/pull/91350 Approved by: https://github.com/EikanWang, https://github.com/anijain2305, https://github.com/jgong5, https://github.com/desertfire
Author
Committer
Parents
Loading