handle mapped_invars correctly in more places
fixes #2822
We didn't handle mapped_invars correctly in all places in #1959. In
particular, in #1959 we:
1. assumed the `mapped_invars` parameter of xla_pmap_p was only
populated after partial_eval and set to None otherwise (i.e.
staging out for a jit or a control flow primitive),
2. didn't update it correctly in JVPTrace.process_map (which adds new
inputs corresponding to nonzero tangents, and hence `mapped_invars`
must be grown),
3. didn't update it correctly in JaxprTrace.process_map (which adds
inputs to the staged-out version of the primitive,
4. didn't forward it correctly in JaxprTrace.process_map anyway (we
were setting it to all-true for the staged out eqn),
5. removed the leading axes of all pvs in JaxprTrace.process_map
regardless of whether the corresponding entry of `mapped_invars`
was True or False.
The reason we didn't notice 2 and 3 was that they only arise when doing
control flow (e.g. scan or remat) of pmap involving closed-over tracers
(apparently a rare case), since that's the case where we first form a
jaxpr (populating `mapped_invars`) and then later have to apply
transformations like AD and further partial eval (thus engaging
JVPTrace.process_map and JaxprTrace.process_map with a populated
`mapped_invars` parameter). It worked in other cases, e.g. when the pmap
was not inside control flow or a remat, because in those cases we left
`mapped_invars` set to None, indicating all-true of any length (so it
didn't matter if we add inputs).
This commit fixes those issues by
1. making `mapped_invars` non-optional (even though the default value
of None is convenient as long as it lasts, it's not worth the
complexity of handling the two None-or-populated cases everywhere
downstream),
2. handling `mapped_invars` correctly in
* JaxprTrace.process_map
* JVPTrace.process_map
* ad.map_transpose (since having symbolic-zero cotangents
effectively prunes inputs, and having undefined-primal args also
prunes inputs)
* ad._eval_subjaxpr_primals (since having undefined-primal args
prunes inputs)
3. making the separate cases of calls and maps handled more explicity
by adding a new Primitive.map_primitive boolean attribute
(analogous to Primitive.call_primitive).
This is begging for a more coherent cleanup. For example, we reuse the
same Primitive class but tag it with `call_primitive` or `map_primitive`
(only one of which can be True); we should instead just have a separate
Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, when `call_primitive=True` or
`map_primitive=True` implies things about what `params` must be present
(`call_jaxpr` and `mapped_invars`). I plan to follow up with those
cleanups, but I wanted to get something working first.