flax
07e513f6
- Plumb spmd_axis_name through transforms.vmap through to JAX vmap
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
3 years ago
Plumb spmd_axis_name through transforms.vmap through to JAX vmap This ensure transforms.vmap matches lift.vmap following https://github.com/google/flax/pull/2390 PiperOrigin-RevId: 468746864
References
#2398 - Plumb spmd_axis_name through transforms.vmap through to JAX vmap
Author
James Lee-Thorp
Committer
a-googler
Parents
2b73efdd
Loading