flax
8efa9975
- Add propagation of param_dtype to carry initializer.
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
2 years ago
Add propagation of param_dtype to carry initializer. This fixes sudden dtype changes when using jax_enable_x64 with RNNs. PiperOrigin-RevId: 545219222
Author
a-googler
Committer
a-googler
Parents
b05c6738
Loading