flax
Add out_sharding argument to call methods for layers with jax calls that support it
#5102
Merged

Loading