flax
718aa8cc - Make `TrainState`'s `step` possibly jax.Array. This makes `replicate` valid for type checking.

Commit
1 year ago
Make `TrainState`'s `step` possibly jax.Array. This makes `replicate` valid for type checking. PiperOrigin-RevId: 615996178
Author
Committer
Parents
Loading