jax
9d220373 - Export `lower_with_sharding_in_types` via jax.extend.mlir.

Commit
107 days ago
Export `lower_with_sharding_in_types` via jax.extend.mlir. Helps users writing lowering rules for custom primitives to preserve sharding information when sharding-in-types/explicit axes are being used. PiperOrigin-RevId: 868812471
Parents
Loading