pytorch
dfa94757 - Check SM version before calling flash attention with BFloat16 (#86600)

Commit
3 years ago
Check SM version before calling flash attention with BFloat16 (#86600)

The flash attention code path requires sm80 or newer to run on BFloat16, so any OpInfo tests running with BFloat16 would fail with the error:

```
RuntimeError: Expected q_dtype == at::kHalf || (is_sm8x && q_dtype == at::kBFloat16) to be true, but got false.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86600
Approved by: https://github.com/ngimel
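The condition in that error message can be sketched as a small standalone predicate. This is an illustration only, not the code from the commit: the function name, the string dtype labels, and the `(major, minor)` tuple argument are hypothetical, and the sketch reads "sm80 or newer" as `major >= 8`, which may be broader than what `is_sm8x` checks in the real C++ source.

```python
from typing import Tuple

# Hypothetical dtype labels for illustration; the real check compares
# at::ScalarType values (at::kHalf, at::kBFloat16) in C++.
HALF = "float16"
BFLOAT16 = "bfloat16"


def flash_attention_dtype_supported(dtype: str, capability: Tuple[int, int]) -> bool:
    """Return True if the dtype passes the check quoted in the error message.

    Half is accepted on any device; BFloat16 is accepted only when the
    device's compute capability is sm80 or newer (assumed: major >= 8).
    """
    major, _minor = capability
    is_sm80_or_newer = major >= 8  # assumption based on the commit message wording
    return dtype == HALF or (is_sm80_or_newer and dtype == BFLOAT16)


if __name__ == "__main__":
    # Compare an sm75 device with an sm80 device for both dtypes.
    for cap in [(7, 5), (8, 0)]:
        for dt in [HALF, BFLOAT16]:
            print(cap, dt, flash_attention_dtype_supported(dt, cap))
```

In a PyTorch program, the capability tuple could come from `torch.cuda.get_device_capability()`, which returns `(major, minor)` for the current CUDA device. A caller would evaluate a predicate like this before dispatching to the flash attention path and fall back to another implementation when it returns `False`.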