add support for auto-tune TMA grid constant (#2428)
Summary:
This PR follows [a recent PR in Triton](https://github.com/triton-lang/triton/pull/4498) that adds support for passing TMA descriptors by value using `__grid_constant__`.
In this PR, we update the `_attn_fwd_inner` kernel to use this new Triton feature. To support auto-tuning, we implement a helper class that wraps the TMA descriptor operations performed during auto-tuning and the computations inside the kernel.
In addition, the benchmark program now checks whether the installed Triton version supports this feature; if it doesn't, the helper class falls back to the old way of handling TMA descriptors.
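The gating idea above can be sketched roughly as follows. This is an illustrative sketch, not the actual helper from this PR; the function and class names are hypothetical, and the feature probe (looking for `nv_tma_desc_type` on `triton.language`) is an assumption about how the new by-value descriptor support can be detected.

```python
def supports_by_value_tma() -> bool:
    """Heuristically check whether the installed Triton supports passing
    TMA descriptors by value (the __grid_constant__ path from
    triton-lang/triton#4498). Assumption: newer Triton exposes an
    nv_tma_desc_type marker in triton.language; older versions do not."""
    try:
        import triton.language as tl
    except ImportError:
        # No Triton installed: the old (or no) TMA path must be used.
        return False
    return hasattr(tl, "nv_tma_desc_type")


class TmaHelperSketch:
    """Hypothetical wrapper: callers ask it for kernel arguments, and it
    dispatches to the new by-value descriptor path when available,
    otherwise to the legacy device-memory descriptor path."""

    def __init__(self) -> None:
        self.by_value = supports_by_value_tma()

    def describe_path(self) -> str:
        # Purely illustrative: report which code path would be taken.
        return "grid_constant" if self.by_value else "legacy_device_desc"
```

On an environment without the new Triton (or without Triton at all), `TmaHelperSketch().describe_path()` returns `"legacy_device_desc"`, matching the fallback behavior described above.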
The change has been tested against the Triton bundled with the standard conda installation of PyTorch, as well as a recent Triton build that includes the above PR.
Command for testing and experiment results:
Before removing fences: P1541573348
After removing fences: P1541736645
1) CUDA_VISIBLE_DEVICES=5, old tma: 138.476
2) CUDA_VISIBLE_DEVICES=5, new tma, with fences: 152-164
3) CUDA_VISIBLE_DEVICES=5, new tma, after removing fences: 168.0
4) CUDA_VISIBLE_DEVICES=5, no tma: 187.881
The TMA results still trail the no-TMA baseline; this can be investigated further.
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2428
Reviewed By: embg
Differential Revision: D61668142
Pulled By: sfzhu93
fbshipit-source-id: d08bab147c6b2197f73447ee8f30ede877e712ca