jax
4db212d2
- Add `_sharding` argument to broadcasted_iota as a private parameter which only works under sharding_in_types mode.
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
1 year ago
Add `_sharding` argument to broadcasted_iota as a private parameter which only works under sharding_in_types mode. This is required because `jax.nn.one_hot` calls into `broascasted_iota`. PiperOrigin-RevId: 687152343
References
#24379 - Add `_sharding` argument to broadcasted_iota as a private parameter which only works under sharding_in_types mode.
Author
yashk2810
Committer
Google-ML-Automation
Parents
dd542630
Loading