flax
07e513f6 - Plumb spmd_axis_name through transforms.vmap through to JAX vmap

Commit
3 years ago
Plumb spmd_axis_name through transforms.vmap through to JAX vmap This ensure transforms.vmap matches lift.vmap following https://github.com/google/flax/pull/2390 PiperOrigin-RevId: 468746864
Author
James Lee-Thorp
Committer
Parents
Loading