flax
4e2becd8 - Preserve sharding information in axes_scan

Commit
239 days ago
Preserve sharding information in axes_scan Also, remove jnp.array call in normalization.py, to preserve sharding information PiperOrigin-RevId: 780954856
Committer
Parents
Loading