[inductor] only check mutations attr for TritonKernel (#92277)
Fixes https://github.com/pytorch/pytorch/issues/93506.
In https://github.com/pytorch/pytorch/pull/91575, a check on the kernel's `mutations` attr was added for in-place buffer reuse:
https://github.com/pytorch/pytorch/blob/5e0d3458eb58d21081f64d6a2347c5462453c2da/torch/_inductor/scheduler.py#L300
Since `mutations` is not tracked for cpp kernels, `getattr(V.kernel, "mutations", None) is not None` is always `False` for them, so cpp kernels never reuse buffers in place.
This PR checks the `mutations` attr only for `TritonKernel`, restoring in-place buffer reuse for cpp kernels.
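A minimal sketch of the gating logic (the stand-in kernel classes and the `mutation_check_passes` helper below are illustrative assumptions, not the actual code in `torch/_inductor/scheduler.py`):
```python
# Illustrative sketch: restrict the `mutations` check to Triton kernels so
# cpp kernels, which never populate `mutations`, can still reuse buffers in place.

class FakeTritonKernel:  # stand-in for torch._inductor.codegen.triton.TritonKernel
    def __init__(self, mutations):
        self.mutations = mutations

class FakeCppKernel:  # stand-in for a cpp kernel, which has no `mutations` attr
    pass

def mutation_check_passes(kernel) -> bool:
    # Old behavior: every kernel had to expose a non-None `mutations`,
    # so cpp kernels always failed and never produced in_out_ptr args.
    # New behavior: only Triton kernels are required to track `mutations`.
    if isinstance(kernel, FakeTritonKernel):
        return getattr(kernel, "mutations", None) is not None
    return True

assert mutation_check_passes(FakeTritonKernel(mutations={"buf0"}))
assert not mutation_check_passes(FakeTritonKernel(mutations=None))
assert mutation_check_passes(FakeCppKernel())  # cpp kernels may reuse in place again
```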
A unit test is added to verify that `in_out_ptr` appears in the generated code.
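A sketch of such a test (not the PR's actual unit test), assuming the `run_and_get_code` helper from `torch._inductor.utils` can capture the generated wrapper source; the model mirrors the bmm-plus-division pattern shown below:
```python
import torch
from torch._inductor.utils import run_and_get_code  # assumed test helper


def fn(a, b):
    # bmm writes buf0, and the elementwise division can then update buf0
    # in place (as in_out_ptr0) instead of allocating a separate output.
    return torch.bmm(a, b) / 8.0


compiled = torch.compile(fn)
a = torch.randn(2, 8, 4)
b = torch.randn(2, 4, 8)

result, (code,) = run_and_get_code(compiled, a, b)
torch.testing.assert_close(result, fn(a, b))
assert "in_out_ptr" in code  # the generated cpp kernel reuses its input buffer
```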
#### Cpp code before this fix:
```python
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_chunyuan/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
float* __restrict__ out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long i0=0; i0<8; i0+=1)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 16*i0);
auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(8.0));
auto tmp2 = tmp0 / tmp1;
tmp2.store(out_ptr0 + 16*i0);
}
#pragma omp for simd simdlen(8)
for(long i0=128; i0<128; i0+=1)
{
auto tmp0 = in_ptr0[i0];
auto tmp1 = static_cast<float>(8.0);
auto tmp2 = tmp0 / tmp1;
out_ptr0[i0] = tmp2;
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    buf0 = empty_strided((2, 8, 8), (64, 8, 1), device='cpu', dtype=torch.float32)
    extern_kernels.bmm(as_strided(arg0_1, (2, 8, 4), (32, 4, 1)), as_strided(arg1_1, (2, 4, 8), (32, 1, 4)), out=buf0)
    del arg0_1
    del arg1_1
    buf1 = empty_strided((1, 2, 8, 8), (128, 64, 8, 1), device='cpu', dtype=torch.float32)
    kernel_cpp_0(c_void_p(buf0.data_ptr()), c_void_p(buf1.data_ptr()))
    return (buf1, )
```
#### Cpp code after this fix:
```python
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_chunyuan/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(float* __restrict__ in_out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long i0=0; i0<8; i0+=1)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + 16*i0);
auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(8.0));
auto tmp2 = tmp0 / tmp1;
tmp2.store(in_out_ptr0 + 16*i0);
}
#pragma omp for simd simdlen(8)
for(long i0=128; i0<128; i0+=1)
{
auto tmp0 = in_out_ptr0[i0];
auto tmp1 = static_cast<float>(8.0);
auto tmp2 = tmp0 / tmp1;
in_out_ptr0[i0] = tmp2;
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    buf0 = empty_strided((2, 8, 8), (64, 8, 1), device='cpu', dtype=torch.float32)
    extern_kernels.bmm(as_strided(arg0_1, (2, 8, 4), (32, 4, 1)), as_strided(arg1_1, (2, 4, 8), (32, 1, 4)), out=buf0)
    del arg0_1
    del arg1_1
    buf1 = as_strided(buf0, (1, 2, 8, 8), (128, 64, 8, 1)); del buf0  # reuse
    kernel_cpp_0(c_void_p(buf1.data_ptr()))
    return (buf1, )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92277
Approved by: https://github.com/jgong5, https://github.com/desertfire