flax
8efa9975 - Add propagation of param_dtype to carry initializer.

Commit
2 years ago
Add propagation of param_dtype to carry initializer. This fixes sudden dtype changes when using jax_enable_x64 with RNNs. PiperOrigin-RevId: 545219222
Author
Committer
Parents
Loading