jax
ebb75db8 - [sharding_in_types] Add `out_type` argument to `einsum` and `dot_general` to allow specifying for the output type. Right now, it only accept a `NamedSharding` but in the future we can allow a polymorphic type of: `jax.ShapeDtypeStruct | Sharding | Layout`.

Commit
1 year ago
[sharding_in_types] Add `out_type` argument to `einsum` and `dot_general` to allow specifying for the output type. Right now, it only accept a `NamedSharding` but in the future we can allow a polymorphic type of: `jax.ShapeDtypeStruct | Sharding | Layout`. PiperOrigin-RevId: 688399552
Author
Parents
Loading