transformers
65753d60 - Remove graph breaks for torch.compile() in flash_attention_forward when Llama Model is padding free tuned (#33932)

Commit
1 year ago
Remove graph breaks for torch.compile() in flash_attention_forward when Llama Model is padding free tuned (#33932)

* fix: fixes for graph breaks
* fix: formatting
* fix: import error
* fix: Add Fa2Kwargs
* fix: PR Changes
* PR changes
* PR changes
* PR changes
* PR changes
* Revert "PR changes"

  This reverts commit 39d2868e5c93cc5f3f3c7c6ff981b66614c0e0e4.

* PR changes
* fix: FlashAttentionKwarg
* fix: FlashAttentionKwarg
* PR Changes
* PR Changes
* PR Changes
* PR Changes
* PR Changes
* addition of documentation
* change in _flash_attention_forward
* make fix-copies
* revert make fix-copies
* fix copies
* style
* loss kwargs typing
* style and pull latest changes

---------

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
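For context on what the commit addresses: in padding-free tuning, sequences are packed into one flattened batch, and FlashAttention-2 needs cumulative sequence lengths (`cu_seqlens`) to find the packed-sequence boundaries. Deriving these inside the forward pass with data-dependent ops is a typical source of torch.compile() graph breaks, so one way to avoid them is to precompute the lengths (e.g. in the data collator) and pass them through as kwargs. The following is a minimal pure-Python sketch of that derivation from position ids; the helper name is hypothetical, and the actual transformers code operates on torch tensors via kwargs such as `FlashAttentionKwargs`:

```python
def cu_seqlens_from_position_ids(position_ids):
    """Derive FlashAttention-2-style cumulative sequence lengths from a
    flattened (padding-free) batch, where position ids restart at 0 at
    the start of each packed sequence.

    Hypothetical helper for illustration only; not the transformers API.
    """
    # Indices where a new packed sequence begins (position id resets to 0).
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    # Boundaries of each sequence, ending at the total token count.
    cu_seqlens = starts + [len(position_ids)]
    # Longest individual sequence in the packed batch.
    max_seqlen = max(b - a for a, b in zip(cu_seqlens, cu_seqlens[1:]))
    return cu_seqlens, max_seqlen

# Two sequences of length 3 and 2, packed without padding tokens.
print(cu_seqlens_from_position_ids([0, 1, 2, 0, 1]))  # → ([0, 3, 5], 3)
```

Because the boundaries are computed once outside the compiled region, the attention forward can consume them as plain inputs instead of branching on tensor values at trace time.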