jax
49224d6c - Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.

Commit
1 year ago
Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective. Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`. PiperOrigin-RevId: 716446406
Author
Parents
Loading