jax
d5e5b42d - Use consistent dtype for forward and backwards in jax.nn.dot_product_attention.

Commit
341 days ago
Use consistent dtype for forward and backwards in jax.nn.dot_product_attention. Fixes https://github.com/jax-ml/jax/issues/24047 PiperOrigin-RevId: 728613700
Author
Parents
Loading