adds Triton FlashAttention-2 kernel (#4337)
* initial commit
* temp commit: needs debugging
* packed flash attn with mask works
* clean-up
* add bert/roberta tests to test_inference
* is_triton_supported added to Accelerator class
* clean-up and formatting
* triton supports flash attention when the GPU compute capability is 8.0 or higher (see the sketch after this list)
* formatting
* fix comments
* cleanup
* cleanup flash kernel
* fixes according to the PR review comments
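
The commits above mention an `is_triton_supported` check on the Accelerator class, gated on GPU compute capability. Below is a minimal, hypothetical sketch of such a gate, written as a standalone function rather than DeepSpeed's actual Accelerator method; the name and structure are illustrative only.

```python
# Illustrative sketch of a compute-capability gate like the one these
# commits describe; not DeepSpeed's actual implementation.
import torch

def is_triton_supported() -> bool:
    """Return True if the current GPU can run the Triton flash-attention kernel."""
    if not torch.cuda.is_available():
        return False
    # Triton's flash-attention kernel requires compute capability >= 8.0
    # (NVIDIA Ampere or newer).
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8
```

A caller would typically use such a check to dispatch to the Triton kernel when supported and fall back to a standard attention implementation otherwise.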
---------
Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>