679e869a - [inductor] only check mutations attr for TritonKernel (#92277)

Fixes https://github.com/pytorch/pytorch/issues/93506.

In https://github.com/pytorch/pytorch/pull/91575, a check on the kernel's `mutations` attr was added for in-place buffer reuse:
https://github.com/pytorch/pytorch/blob/5e0d3458eb58d21081f64d6a2347c5462453c2da/torch/_inductor/scheduler.py#L300

Since `mutations` is not tracked in cpp kernels, `getattr(V.kernel, "mutations", None) is not None` is always `False` for them, so cpp kernels never reuse an input buffer in place and always write to a separate output buffer. This PR checks the `mutations` attr only for `TritonKernel`. A unit test is added to guarantee that `in_out_ptr` appears in the generated code.

#### Cpp code before this fix:
```python
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_chunyuan/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
                       float* __restrict__ out_ptr0)
{
    #pragma omp parallel num_threads(64)
    {
        {
            #pragma omp for
            for(long i0=0; i0<8; i0+=1)
            {
                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 16*i0);
                auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(8.0));
                auto tmp2 = tmp0 / tmp1;
                tmp2.store(out_ptr0 + 16*i0);
            }
            #pragma omp for simd simdlen(8)
            for(long i0=128; i0<128; i0+=1)
            {
                auto tmp0 = in_ptr0[i0];
                auto tmp1 = static_cast<float>(8.0);
                auto tmp2 = tmp0 / tmp1;
                out_ptr0[i0] = tmp2;
            }
        }
    }
}
''')


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    buf0 = empty_strided((2, 8, 8), (64, 8, 1), device='cpu', dtype=torch.float32)
    extern_kernels.bmm(as_strided(arg0_1, (2, 8, 4), (32, 4, 1)), as_strided(arg1_1, (2, 4, 8), (32, 1, 4)), out=buf0)
    del arg0_1
    del arg1_1
    buf1 = empty_strided((1, 2, 8, 8), (128, 64, 8, 1), device='cpu', dtype=torch.float32)
    kernel_cpp_0(c_void_p(buf0.data_ptr()), c_void_p(buf1.data_ptr()))
    return (buf1, )
```

#### Cpp code after this fix:
```python
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_chunyuan/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(float* __restrict__ in_out_ptr0)
{
    #pragma omp parallel num_threads(64)
    {
        {
            #pragma omp for
            for(long i0=0; i0<8; i0+=1)
            {
                auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + 16*i0);
                auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(8.0));
                auto tmp2 = tmp0 / tmp1;
                tmp2.store(in_out_ptr0 + 16*i0);
            }
            #pragma omp for simd simdlen(8)
            for(long i0=128; i0<128; i0+=1)
            {
                auto tmp0 = in_out_ptr0[i0];
                auto tmp1 = static_cast<float>(8.0);
                auto tmp2 = tmp0 / tmp1;
                in_out_ptr0[i0] = tmp2;
            }
        }
    }
}
''')


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    buf0 = empty_strided((2, 8, 8), (64, 8, 1), device='cpu', dtype=torch.float32)
    extern_kernels.bmm(as_strided(arg0_1, (2, 8, 4), (32, 4, 1)), as_strided(arg1_1, (2, 4, 8), (32, 1, 4)), out=buf0)
    del arg0_1
    del arg1_1
    buf1 = as_strided(buf0, (1, 2, 8, 8), (128, 64, 8, 1)); del buf0  # reuse
    kernel_cpp_0(c_void_p(buf1.data_ptr()))
    return (buf1, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92277
Approved by: https://github.com/jgong5, https://github.com/desertfire
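For illustration, below is a minimal, self-contained sketch of the scheduler guard described above. The class names and helper functions are stand-ins for this example, not the actual `torch._inductor` code, and the exact condition in the PR may differ.

```python
# Sketch of the in-place-reuse guard this commit adjusts (stand-in classes,
# not torch._inductor internals).

class CppKernel:
    pass  # cpp kernels do not carry a `mutations` set


class TritonKernel:
    def __init__(self):
        self.mutations = set()  # buffers the Triton kernel mutates in place


def reuse_allowed_before(kernel):
    # Pre-fix guard: demands a `mutations` attr on every kernel, so cpp
    # kernels never qualify and `in_out_ptr` reuse is silently skipped.
    return getattr(kernel, "mutations", None) is not None


def reuse_allowed_after(kernel):
    # Post-fix guard: only Triton kernels need the `mutations` attr; other
    # kernel types (e.g. cpp) stay eligible for in-place buffer reuse.
    return not isinstance(kernel, TritonKernel) or getattr(kernel, "mutations", None) is not None


print(reuse_allowed_before(CppKernel()))    # False -> separate out_ptr0 before the fix
print(reuse_allowed_after(CppKernel()))     # True  -> in_out_ptr0 after the fix
print(reuse_allowed_after(TritonKernel()))  # True  -> Triton path unchanged
```

This mirrors the generated-code comparison above: before the fix the cpp kernel writes `buf1` from `in_ptr0`/`out_ptr0`, while after the fix `buf0` is reinterpreted with `as_strided` and updated through a single `in_out_ptr0`.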