jax
b12e491f - Validate that an explicit axis_size passed to vmap() matches the sizes of the mapped axes.

Commit
4 days ago
Validate that an explicit axis_size passed to vmap() matches the sizes of the mapped axes. It seems we weren't checking this. PiperOrigin-RevId: 892515493
Author
Parents
Loading