jax
6ca53832
- Make sure to use `to_ct_spec()` when unsharding Zeros in shard_map_transpose. Also use `to_ct_aval` in _flatten_bwd in custom_vjp
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
49 days ago
Make sure to use `to_ct_spec()` when unsharding Zeros in shard_map_transpose. Also use `to_ct_aval` in _flatten_bwd in custom_vjp PiperOrigin-RevId: 884579658
References
#36023 - Postrelease JAX v0.9.2.
Author
yashk2810
Committer
danielsuo
Parents
ed8e9d55
Loading