jax
b7b26457 - fix vmap-of-grad-of-shmap axis name reuse bug

Commit
1 year ago
fix vmap-of-grad-of-shmap axis name reuse bug

When we write `vmap(f, spmd_axis_name=A)`, we require that `f` does not mention A in any specs, such as the `PartitionSpec` in a `with_sharding_constraint` or the `in_specs`/`out_specs` of a `shard_map`.

Previously, shard_map autodiff violated that requirement, because it sharded residuals over all mesh axes, including A whenever A was present in the mesh. As a result, the vmap rule could then insert a redundant appearance of A.

This commit fixes the problem by sharding residuals only over the mesh axes mentioned in in_specs; residuals can at most be sharded over those mesh axes. The vmap rule is then free to introduce an occurrence of A in the specs.
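The pattern described above can be sketched as follows. This is a minimal illustration, not the test shipped with the commit: the mesh axis names `i` and `j`, the function `f`, and the input shape are assumptions chosen for the example, and the 1x1 single-device mesh exists only so the snippet runs on any machine. The import path of `shard_map` differs between JAX releases, so both locations are tried.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# Assumption: shard_map is importable from one of these two locations,
# depending on the installed JAX release.
try:
    from jax import shard_map
except ImportError:
    from jax.experimental.shard_map import shard_map

# A 1x1 mesh over a single device. A real setup would spread the
# axes "i" and "j" over several devices.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("i", "j"))

def f(x):
    return jnp.sin(x)

# The specs mention only mesh axis "j"; axis "i" is left unmentioned.
sharded_f = shard_map(f, mesh=mesh, in_specs=P("j"), out_specs=P("j"))

def loss(x):
    return sharded_f(x).sum()

xs = jnp.arange(4 * 16.0).reshape(4, 16)

# vmap over the leading dimension, assigning the batch axis to mesh axis "i".
# The autodiff residuals of sharded_f must not already be sharded over "i",
# otherwise "i" would appear twice once vmap adds it.
grads = jax.vmap(jax.grad(loss), spmd_axis_name="i")(xs)
print(grads.shape)
```

Since the gradient of `sin(x).sum()` is `cos(x)`, comparing `grads` against `jnp.cos(xs)` is a quick sanity check that the composition runs and differentiates correctly.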