flax
c87b54b7 - Add support for non-float32 normalization for linen normalization layers.

Commit
3 years ago
Add support for non-float32 normalization for linen normalization layers. Previously we always casted to float32 even if this is a downcast. For example f64 -> f32 or complex -> float. This change maxes the reduction dtype *at least* float32
Author
Committer
Parents
Loading