jax
0e68e5f4 - Use the cotangent spec when calculating which mesh axes to do psum over if check_vma=False in shard_map transpose.

Commit
3 days ago
Use the cotangent spec when calculating which mesh axes to do psum over if check_vma=False in shard_map transpose. This was a bug where we were inserting un-necessary psums because of the wrong pspec in transpose. PiperOrigin-RevId: 920421378
Author
Parents
Loading