df14650f - [SDPA] Update SDPA API and make function Public (#92189)

# Summary

In preparation for the PyTorch 2.0 launch, this PR updates SDPA's API and makes the function a public `nn.functional` function.

## Changes

### API

Previously, the function signature was:

`scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)`

Updated signature:

`scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor`

This PR removes the optional `need_attn_weights` boolean argument and changes the return type to a single tensor.

#### Reasoning

The main goal of this function is to provide an easy interface for users to call into fused attention kernels, e.g. FlashAttention. The fused kernels do not currently support an arbitrary `attn_mask` or dropout, but there is a PR to mem-efficient attention to enable these, and we want the API surface ready for when the backing kernels are updated. The fused kernels save memory by not materializing the attention weights, and it is unlikely that a fast fused implementation will ever support returning them, so we are removing that option. This API change was discussed with folks at FAIR/xFormers, who are +1 on it.

#### Make function public

In preparation for the PyTorch 2.0 launch, we make the function public so we can start gathering user feedback.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92189
Approved by: https://github.com/cpuhrsch
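For context, here is a minimal sketch of how the updated API is called once this lands, assuming PyTorch 2.0+ where the function is exposed as `torch.nn.functional.scaled_dot_product_attention`; the tensor shapes below are illustrative, not mandated by the API:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes following the usual (batch, num_heads, seq_len, head_dim) layout.
batch, num_heads, seq_len, head_dim = 2, 8, 128, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# After this PR the function returns a single tensor of attention outputs;
# there is no need_attn_weights flag and no second return value.
out = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,   # optional mask; leaving it None keeps fused kernels eligible
    dropout_p=0.0,
    is_causal=True,   # apply a causal mask without materializing it
)
assert out.shape == (batch, num_heads, seq_len, head_dim)
```

Because the attention weights are never materialized by the fused kernels, callers who previously relied on `need_attn_weights=True` must compute the weights themselves with an explicit softmax if they need them.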