[inductor] optimize lowering for empty-related operators (#91350)
For micro-benchmark, `new_empty_strided` and `new_empty` have poor performance with inductor compared to eager. The main reason is that inductor initializes new tensor with 0 during lowering, which generates a useless cpp kernel. Actually, it is not needed for operator semantics, but costs additional time. The same problem is also found in lowerings of `empty_strided` and `empty`. This PR tends to remove useless cpp kernel of tensor initialization by generating a NopKernelSchedulerNode instead of a SchedulerNode. The lowering functions of following operators are optimized:
- `torch.empty`
- `aten.empty`
- `aten.new_empty`
- `aten.empty_strided`
- `aten.new_empty_strided`
We take output code of `new_empty_strided` as example.
_Before change_
```
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_root/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(float* __restrict__ out_ptr0)
{
#pragma omp parallel num_threads(28)
{
#pragma omp for
for(long i0=0; i0<57600; i0+=1)
{
auto tmp0 = at::vec::Vectorized<float>(static_cast<float>(0));
tmp0.store(out_ptr0 + 16*i0);
}
#pragma omp for simd simdlen(8)
for(long i0=921600; i0<921600; i0+=1)
{
auto tmp0 = static_cast<float>(0);
out_ptr0[i0] = tmp0;
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, = args
args.clear()
buf0 = empty_strided((60, 60, 256), (15360, 256, 1), device='cpu', dtype=torch.float32)
kernel_cpp_0(c_void_p(buf0.data_ptr()))
return (buf0, )
if __name__ == "__main__":
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((60, 60, 256), (60, 1, 3600), device='cpu', dtype=torch.float32)
print_performance(lambda: call([arg0_1]))
```
_After change_
```
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, = args
args.clear()
buf0 = empty_strided((60, 60, 256), (15360, 256, 1), device='cpu', dtype=torch.float32)
return (buf0, )
if __name__ == "__main__":
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((60, 60, 256), (60, 1, 3600), device='cpu', dtype=torch.float32)
print_performance(lambda: call([arg0_1]))
```
Performance data for eager v.s. inductor:
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">
<head>
<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/xuanliao/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/xuanliao/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
</head>
<body link="#0563C1" vlink="#954F72">
suite | op_name | improved_ratio_speedup0.2 | improved_ratio_speedup0.5 | improved_ratio_speedup0.8 | speedup_old_0.2 | RSD(3) | speedup_old_0.5 | RSD(3) | speedup_old_0.8 | RSD(3) | speedup_new_0.2 | RSD(3) | speedup_new_0.5 | RSD(3) | speedup_new_0.8 | RSD(3)
-- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | --
torchbench | aten.new_empty_strided.default | 235.94% | 100.94% | 50.23% | 0.325947 | 2.96% | 0.550267 | 2.03% | 0.747997 | 2.93% | 1.094985 | 0.81% | 1.105722 | 0.55% | 1.12372 | 0.68%
huggingface | aten.new_empty_strided.default | 120.58% | 81.16% | 87.41% | 0.503116 | 28.27% | 0.668831 | 5.85% | 0.705637 | 2.76% | 1.109785 | 1.70% | 1.211641 | 0.74% | 1.322434 | 0.82%
timm | aten.new_empty_strided.default | 129.24% | 72.75% | 47.91% | 0.490658 | 15.87% | 0.76711 | 13.11% | 0.904033 | 4.44% | 1.124806 | 1.19% | 1.325182 | 0.65% | 1.337114 | 1.01%
torchbench | aten.new_empty.default | 69.41% | 1.60% | 0.90% | 0.732117 | 5.24% | 1.228356 | 1.18% | 1.241341 | 0.81% | 1.24031 | 1.96% | 1.248061 | 1.70% | 1.252525 | 1.84%
huggingface | aten.new_empty.default | 150.01% | 79.29% | 39.91% | 0.49547 | 12.67% | 0.692498 | 22.11% | 0.889526 | 27.37% | 1.238706 | 1.58% | 1.241606 | 1.49% | 1.244506 | 1.41%
timm | aten.new_empty.default | 11.61% | 11.13% | 11.07% | 1.115127 | 0.65% | 1.124302 | 0.80% | 1.132986 | 1.38% | 1.244582 | 1.12% | 1.249459 | 1.31% | 1.258416 | 1.14%
</body>
</html>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91350
Approved by: https://github.com/EikanWang, https://github.com/anijain2305, https://github.com/jgong5, https://github.com/desertfire