Remove graph breaks for torch.compile() in flash_attention_forward when the Llama model is padding-free tuned (#33932)
* fix: fixes for graph breaks
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* fix: formatting
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* fix: import error
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* fix: Add Fa2Kwargs
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* fix: PR Changes
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* PR changes
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* PR changes
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* PR changes
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* PR changes
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* Revert "PR changes"
This reverts commit 39d2868e5c93cc5f3f3c7c6ff981b66614c0e0e4.
* PR changes
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* fix: FlashAttentionKwarg
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* fix: FlashAttentionKwarg
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* PR Changes
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* PR Changes
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* PR Changes
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* PR Changes
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* PR Changes
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* add documentation
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* change in _flash_attention_forward
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* make fix-copies
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* revert make fix-copies
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
* fix copies
* style
* loss kwargs typing
* style and pull latest changes
---------
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>