jax
02007235 - Make sure to use `to_ct_spec()` when unsharding Zeros in shard_map_transpose. Also use `to_ct_aval` in _flatten_bwd in custom_vjp

Commit
2 days ago
Make sure to use `to_ct_spec()` when unsharding Zeros in shard_map_transpose. Also use `to_ct_aval` in _flatten_bwd in custom_vjp PiperOrigin-RevId: 884579658
Author
Parents
Loading