pytorch
9a2becf6 - inductor: fix inplace op's wrong lowering issue when preop is NopKernel (#92247)

Commit
2 years ago
inductor: fix inplace op's wrong lowering issue when preop is NopKernel (#92247) For TIMM ghostnet_100, there has such case, concat+inplace_add: ``` import torch from torch._inductor import config config.debug = True torch._dynamo.config.verbose=True class MockModule(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y, z): out = torch.cat([x, y], dim=1) out+=z return out mod = MockModule().eval() inputs = ( torch.randn([1, 64, 16, 16]), torch.randn([1, 64, 16, 16]), torch.randn([1, 128, 16, 16]), ) ref = mod(*inputs) with torch.no_grad(): opt_model = torch._dynamo.optimize('inductor')(mod) out = opt_model(*inputs) out = opt_model(*inputs) out = opt_model(*inputs) print(torch.equal(ref, out)) ``` the inductor always get a wrong result, I find that inductor get a wrong code: ``` from ctypes import c_void_p, c_long import torch import random from torch import empty_strided, as_strided, device from torch._inductor.codecache import AsyncCompile from torch._inductor.select_algorithm import extern_kernels aten = torch.ops.aten assert_size_stride = torch._C._dynamo.guards.assert_size_stride async_compile = AsyncCompile() kernel_cpp_0 = async_compile.cpp(''' #include "/tmp/torchinductor_xiaobing/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h" extern "C" void kernel(const float* __restrict__ in_ptr0, const float* __restrict__ in_ptr1, const float* __restrict__ in_ptr2, const float* __restrict__ in_ptr3, float* __restrict__ out_ptr0, float* __restrict__ out_ptr1, float* __restrict__ out_ptr2) { { for(long i0=0; i0<1024; i0+=1) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 16*i0); tmp0.store(out_ptr0 + 16*i0); } #pragma omp simd simdlen(8) for(long i0=16384; i0<16384; i0+=1) { auto tmp0 = in_ptr0[i0]; out_ptr0[i0] = tmp0; } } { for(long i0=0; i0<1024; i0+=1) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr1 + 16*i0); tmp0.store(out_ptr1 + 16*i0); } #pragma omp simd simdlen(8) for(long i0=16384; i0<16384; i0+=1) { auto tmp0 = in_ptr1[i0]; out_ptr1[i0] = tmp0; } } { for(long i0=0; i0<2048; i0+=1) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr2 + 16*i0); auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr3 + 16*i0); auto tmp2 = tmp0 + tmp1; tmp2.store(out_ptr2 + 16*i0); } #pragma omp simd simdlen(8) for(long i0=32768; i0<32768; i0+=1) { auto tmp0 = in_ptr2[i0]; auto tmp1 = in_ptr3[i0]; auto tmp2 = tmp0 + tmp1; out_ptr2[i0] = tmp2; } } } ''') async_compile.wait(globals()) del async_compile def call(args): arg0_1, arg1_1, arg2_1 = args args.clear() buf3 = empty_strided((1, 128, 16, 16), (32768, 256, 16, 1), device='cpu', dtype=torch.float32) buf0 = as_strided(buf3, (1, 64, 16, 16), (32768, 256, 16, 1)) # alias buf1 = as_strided(buf3, (1, 64, 16, 16), (32768, 256, 16, 1), 16384) # alias buf2 = empty_strided((1, 128, 16, 16), (32768, 256, 16, 1), device='cpu', dtype=torch.float32) kernel_cpp_0(c_void_p(arg0_1.data_ptr()), c_void_p(arg1_1.data_ptr()), c_void_p(buf2.data_ptr()), c_void_p(arg2_1.data_ptr()), c_void_p(buf0.data_ptr()), c_void_p(buf1.data_ptr()), c_void_p(buf3.data_ptr())) del arg0_1 del arg1_1 del arg2_1 return (buf3, ) if __name__ == "__main__": from torch._dynamo.testing import rand_strided from torch._inductor.utils import print_performance arg0_1 = rand_strided((1, 64, 16, 16), (16384, 256, 16, 1), device='cpu', dtype=torch.float32) arg1_1 = rand_strided((1, 64, 16, 16), (16384, 256, 16, 1), device='cpu', dtype=torch.float32) arg2_1 = rand_strided((1, 128, 16, 16), (32768, 256, 16, 1), device='cpu', dtype=torch.float32) print_performance(lambda: call([arg0_1, arg1_1, arg2_1])) ``` you can see that the add operation always adds a random value, see the ir code: 1. **ir_pre_fusion.txt** ``` buf0: SchedulerNode(ComputedBuffer) buf0.writes = [MemoryDep(name='buf0', index=c0, size=(16384,))] buf0.unmet_dependencies = [] buf0.met_dependencies = [MemoryDep(name='arg0_1', index=c0, size=(16384,))] buf0.group.device = cpu buf0.group.iteration = ((16384,), ()) buf0.sizes = ([16384], []) buf0.aliases = ['buf3'] class buf0_loop_body: var_ranges = {z0: 16384} index0 = z0 def body(self, ops): get_index = self.get_index('index0') load = ops.load('arg0_1', get_index) get_index_1 = self.get_index('index0') store = ops.store('buf0', get_index_1, load, None) return store buf1: SchedulerNode(ComputedBuffer) buf1.writes = [MemoryDep(name='buf1', index=c0, size=(16384,))] buf1.unmet_dependencies = [] buf1.met_dependencies = [MemoryDep(name='arg1_1', index=c0, size=(16384,))] buf1.group.device = cpu buf1.group.iteration = ((16384,), ()) buf1.sizes = ([16384], []) buf1.aliases = ['buf3'] class buf1_loop_body: var_ranges = {z0: 16384} index0 = z0 def body(self, ops): get_index = self.get_index('index0') load = ops.load('arg1_1', get_index) get_index_1 = self.get_index('index0') store = ops.store('buf1', get_index_1, load, None) return store buf2: NopKernelSchedulerNode(ConcatKernel) buf2.writes = [StarDep(name='buf2')] buf2.unmet_dependencies = [StarDep(name='buf0'), StarDep(name='buf1')] buf2.met_dependencies = [] buf3: SchedulerNode(ComputedBuffer) buf3.writes = [MemoryDep(name='buf3', index=c0, size=(32768,))] buf3.unmet_dependencies = [MemoryDep(name='buf2', index=c0, size=(32768,))] buf3.met_dependencies = [MemoryDep(name='arg2_1', index=c0, size=(32768,))] buf3.group.device = cpu buf3.group.iteration = ((32768,), ()) buf3.sizes = ([32768], []) class buf3_loop_body: var_ranges = {z0: 32768} index0 = z0 def body(self, ops): get_index = self.get_index('index0') load = ops.load('buf2', get_index) get_index_1 = self.get_index('index0') load_1 = ops.load('arg2_1', get_index_1) add = ops.add(load, load_1) get_index_2 = self.get_index('index0') store = ops.store('buf3', get_index_2, add, None) return store ``` 2. **ir_post_fusion.txt** ``` buf0: SchedulerNode(ComputedBuffer) buf0.writes = [MemoryDep(name='buf0', index=c0, size=(16384,))] buf0.unmet_dependencies = [] buf0.met_dependencies = [MemoryDep(name='arg0_1', index=c0, size=(16384,))] buf0.group.device = cpu buf0.group.iteration = ((16384,), ()) buf0.sizes = ([16384], []) buf0.aliases = ['buf3'] class buf0_loop_body: var_ranges = {z0: 16384} index0 = z0 def body(self, ops): get_index = self.get_index('index0') load = ops.load('arg0_1', get_index) get_index_1 = self.get_index('index0') store = ops.store('buf0', get_index_1, load, None) return store buf1: SchedulerNode(ComputedBuffer) buf1.writes = [MemoryDep(name='buf1', index=c0, size=(16384,))] buf1.unmet_dependencies = [] buf1.met_dependencies = [MemoryDep(name='arg1_1', index=c0, size=(16384,))] buf1.group.device = cpu buf1.group.iteration = ((16384,), ()) buf1.sizes = ([16384], []) buf1.aliases = ['buf3'] class buf1_loop_body: var_ranges = {z0: 16384} index0 = z0 def body(self, ops): get_index = self.get_index('index0') load = ops.load('arg1_1', get_index) get_index_1 = self.get_index('index0') store = ops.store('buf1', get_index_1, load, None) return store buf2: NopKernelSchedulerNode(ConcatKernel) buf2.writes = [StarDep(name='buf2')] buf2.unmet_dependencies = [StarDep(name='buf0'), StarDep(name='buf1')] buf2.met_dependencies = [] buf3: SchedulerNode(ComputedBuffer) buf3.writes = [MemoryDep(name='buf3', index=c0, size=(32768,))] buf3.unmet_dependencies = [MemoryDep(name='buf2', index=c0, size=(32768,))] buf3.met_dependencies = [MemoryDep(name='arg2_1', index=c0, size=(32768,))] buf3.group.device = cpu buf3.group.iteration = ((32768,), ()) buf3.sizes = ([32768], []) class buf3_loop_body: var_ranges = {z0: 32768} index0 = z0 def body(self, ops): get_index = self.get_index('index0') load = ops.load('buf2', get_index) get_index_1 = self.get_index('index0') load_1 = ops.load('arg2_1', get_index_1) add = ops.add(load, load_1) get_index_2 = self.get_index('index0') store = ops.store('buf3', get_index_2, add, None) return store ``` From the ir code, you can see the buf3 always adds an empty buf2 which has never been written. The root cause is that there has a potential issue when doing the mutation for inplace add when its' input is a NopKernel. After this PR, the ir will be like(**ir_pre_fusion.txt**): ``` buf0: SchedulerNode(ComputedBuffer) buf0.writes = [MemoryDep(name='buf0', index=c0, size=(16384,))] buf0.unmet_dependencies = [] buf0.met_dependencies = [MemoryDep(name='arg0_1', index=c0, size=(16384,))] buf0.group.device = cpu buf0.group.iteration = ((16384,), ()) buf0.sizes = ([16384], []) buf0.aliases = ['buf2'] class buf0_loop_body: var_ranges = {z0: 16384} index0 = z0 def body(self, ops): get_index = self.get_index('index0') load = ops.load('arg0_1', get_index) get_index_1 = self.get_index('index0') store = ops.store('buf0', get_index_1, load, None) return store buf1: SchedulerNode(ComputedBuffer) buf1.writes = [MemoryDep(name='buf1', index=c0, size=(16384,))] buf1.unmet_dependencies = [] buf1.met_dependencies = [MemoryDep(name='arg1_1', index=c0, size=(16384,))] buf1.group.device = cpu buf1.group.iteration = ((16384,), ()) buf1.sizes = ([16384], []) buf1.aliases = ['buf2'] class buf1_loop_body: var_ranges = {z0: 16384} index0 = z0 def body(self, ops): get_index = self.get_index('index0') load = ops.load('arg1_1', get_index) get_index_1 = self.get_index('index0') store = ops.store('buf1', get_index_1, load, None) return store buf2: NopKernelSchedulerNode(ConcatKernel) buf2.writes = [StarDep(name='buf2')] buf2.unmet_dependencies = [StarDep(name='buf0'), StarDep(name='buf1')] buf2.met_dependencies = [] buf3: SchedulerNode(ComputedBuffer) buf3.writes = [MemoryDep(name='buf3', index=c0, size=(32768,))] buf3.unmet_dependencies = [MemoryDep(name='buf2', index=c0, size=(32768,)), StarDep(name='buf2')] buf3.met_dependencies = [MemoryDep(name='arg2_1', index=c0, size=(32768,))] buf3.group.device = cpu buf3.group.iteration = ((32768,), ()) buf3.sizes = ([32768], []) buf3.mutations = ['buf2'] class buf3_loop_body: var_ranges = {z0: 32768} index0 = z0 def body(self, ops): get_index = self.get_index('index0') load = ops.load('buf2', get_index) get_index_1 = self.get_index('index0') load_1 = ops.load('arg2_1', get_index_1) add = ops.add(load, load_1) get_index_2 = self.get_index('index0') store = ops.store('buf3', get_index_2, add, None) return store ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/92247 Approved by: https://github.com/ngimel, https://github.com/desertfire, https://github.com/jansel
Author
Committer
Parents
Loading