jax
2ae653f4 - [pallas] Sketched an `mpmd_map` primitive

Commit
75 days ago
[pallas] Sketched an `mpmd_map` primitive The new primitive is similar to `core_map` in that it it allows to map a function over a mesh. However, unlike `core_map` * it accepts multiple mesh-function pairs; * it does not currently support closed-over values, including refs, and * it uses value semantics, so the primitive accepts and returns values. Despite the name, though, there is nothing MPMD about it just yet. The lowering only supports a single mesh-function. This limitation will be lifeted if follow up changes. The primitive is used in the `default_mesh_discharge_rule`. PiperOrigin-RevId: 871266518
Author
Parents
Loading