pytorch
b89b5716 - ROCm fixes for PyT2.0 (#100089)

Commit
1 year ago
ROCm fixes for PyT2.0 (#100089) This PR brings some updates and fixes in regards to PyT2.0 functionality 1 - ROCm's version of triton does not yet support tl.reduce Until supported we are opting to revert the removal of the aten.prod make_fallback for ROCm brought in with https://github.com/pytorch/pytorch/commit/7a6c650b8150231ad81869a3d1201ba0443ecad8 This issue was found locally with the latest aten.prod UTs on ROCm ``` FAILED [0.0916s] inductor/test_torchinductor.py::CudaTests::test_prod_cuda - torch._dynamo.exc.BackendCompilerFailed: backend='compile_fx_wrapper' raised: AttributeError: module 'triton.language' has no attribute 'reduce' ``` 2 - Adds aten.miopen_batch_norm as an explicit fallback as perf issues are observed when registered as a decomposition, setting warning=False as the fallback is expected 3 - Fixes a typo and redundant assignment in _inductor/triton_heuristics.py brought in with https://github.com/pytorch/pytorch/pull/99756/commits/dd778a76103a9a0c85cc487f851160824fe6124c Pull Request resolved: https://github.com/pytorch/pytorch/pull/100089 Approved by: https://github.com/kit1980, https://github.com/pruthvistony, https://github.com/jithunnair-amd, https://github.com/malfet, https://github.com/jansel
Author
Committer
Parents
Loading