jax
e854f165 - Allow `P.UNCONSTRAINED` in out_shardings at top level jit. This is required for sharding in types to work properly when out_avals contain UNCONSTRAINED specs.

Commit
1 year ago
Allow `P.UNCONSTRAINED` in out_shardings at top level jit. This is required for sharding in types to work properly when out_avals contain UNCONSTRAINED specs. This also simplifies the `impl` rule of `sharding_cast`. PiperOrigin-RevId: 707349491
Author
Parents
Loading