fix vmap-of-grad-of-shmap axis name reuse bug

When we write `vmap(f, spmd_axis_name=A)`, we require that `f` does not mention
A in any specs, such as the `PartitionSpec` in a `with_sharding_constraint` or
the `in_specs`/`out_specs` of a `shard_map`. Previously, shard_map autodiff
violated that requirement, since we sharded residuals over all mesh axes,
including A whenever A was in the mesh. As a result, the vmap rule could then
insert a second, redundant occurrence of A.

This commit fixes the problem by sharding residuals only over the mesh axes
mentioned in `in_specs`; residuals can be sharded over at most those axes
anyway. The vmap rule is then free to introduce an occurrence of A in the
specs.
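
A minimal sketch of the pattern this affects, not taken from the test suite
as-is. The mesh axis names 'A' and 'B', the function `f`, and the 1x1 mesh
over a single device are all assumptions chosen so the snippet runs anywhere;
a real repro would use a multi-device mesh. The import fallback is there
because the location of `shard_map` differs between JAX releases.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# Assumption: a 1x1 mesh over the first available device, with axes 'A', 'B'.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('A', 'B'))

def f(x):
    # sin keeps x as a residual for the backward pass.
    return jnp.sin(x)

# The specs mention only 'B', so 'A' stays free for vmap's spmd_axis_name.
sharded_f = shard_map(f, mesh=mesh, in_specs=P('B'), out_specs=P('B'))

def loss(x):
    return sharded_f(x).sum()

xs = jnp.arange(2 * 4.0).reshape(2, 4)

# vmap-of-grad-of-shmap: the vmap rule adds 'A' to the specs of the batched
# dimension, which only works if the residual specs do not already name 'A'.
grads = jax.vmap(jax.grad(loss), spmd_axis_name='A')(xs)
```

With the fix, the residual specs here can mention at most 'B', so the call
should trace and run without an axis name conflict.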