jax
d5e5b42d
- Use consistent dtype for forward and backwards in jax.nn.dot_product_attention.
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
341 days ago
Use consistent dtype for forward and backwards in jax.nn.dot_product_attention. Fixes https://github.com/jax-ml/jax/issues/24047 PiperOrigin-RevId: 728613700
References
#240 - CI: 02/25/25 upstream sync
Author
sbodenstein
Committer
Google-ML-Automation
Parents
44872296
Loading