jax
718ac012 - Couple of changes in this PR:

Commit
72 days ago
Couple of changes in this PR: * Add a test for grad(shmap) where we have array-like HiVal, HiPrimitive, HiType, HipSpec. * Add `to_cotangent_spec` to `HipSpec`. * Add `nospec` to `HiType` (part of shard_map interface) so that we can get a `HipSpec` from the underlying vmas. There's more subtlety here like under check_vma you will get all_manual_names passed in so you can give us a HipSpec wrt to that. Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 878185697
Author
Parents
Loading