jax
6fa4e041 - Move the pallas_call batching logic which removes explicit_mesh_axis_names from AxisEnv into core.py and enter into that context in BatchTrace.process_primitive

Commit
27 days ago
Move the pallas_call batching logic which removes explicit_mesh_axis_names from AxisEnv into core.py and enter into that context in BatchTrace.process_primitive PiperOrigin-RevId: 859835616
Author
Parents
Loading