Make `axis_data.explicit_mesh_axis` dynamic whose value depends on the current mesh context.
If you are under `auto_axes` context but `explicit_mesh_axis` was calculated under `Explicit` context, then when batching rule of primitives are called under auto_axes, `explicit_mesh_axis` needs to be `None` instead of a concrete mesh axis.
Fixes https://github.com/jax-ml/jax/issues/29839
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 778134085