jax
4db212d2 - Add `_sharding` argument to broadcasted_iota as a private parameter which only works under sharding_in_types mode.

Commit
1 year ago
Add `_sharding` argument to broadcasted_iota as a private parameter which only works under sharding_in_types mode. This is required because `jax.nn.one_hot` calls into `broascasted_iota`. PiperOrigin-RevId: 687152343
Author
Parents
Loading