jax
6ca53832 - Make sure to use `to_ct_spec()` when unsharding Zeros in shard_map_transpose. Also use `to_ct_aval` in _flatten_bwd in custom_vjp

Commit
49 days ago
Make sure to use `to_ct_spec()` when unsharding Zeros in shard_map_transpose. Also use `to_ct_aval` in _flatten_bwd in custom_vjp PiperOrigin-RevId: 884579658
Author
Committer
Parents
Loading