use the enable_gqa param in torch.nn.functional.scaled_dot_product_at… (#39412)
* use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
* fix CI failure
* add check
* fix CI failure
* refine code, extend to CUDA
* refine code
* fix review comments
* refine the PR
---------
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>