jax
50476ee1 - Show all mismatches in `flatten_axes_resources()` for pytree prefix errors instead of just the first mismatch. This prevents multiple roundtrips just to identify all pytree mismatches one by one.

Commit
61 days ago
Show all mismatches in `flatten_axes_resources()` for pytree prefix errors instead of just the first mismatch. This prevents multiple roundtrips just to identify all pytree mismatches one by one. PiperOrigin-RevId: 880644317
Author
Parents
Loading