Relax dimension ordering rules for dot_general.
JAX currently requires batch dimensions to be the leading, contiguous dimensions of both operands to dot_general. XLA has no such requirement, so relax JAX's checks to allow batch dimensions in arbitrary positions.
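As an illustration of what the relaxed check permits, here is a minimal sketch of a batched matrix multiply whose batch dimension is neither leading nor in the same position in both operands. The shapes, variable names, and the `einsum` cross-check are assumptions chosen for this example and are not taken from the change itself.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax

# Illustrative shapes only: lhs is (m, batch, k), rhs is (k, n, batch).
lhs = jnp.arange(3 * 4 * 5, dtype=jnp.float32).reshape(3, 4, 5)
rhs = jnp.arange(5 * 6 * 4, dtype=jnp.float32).reshape(5, 6, 4)

# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dimension_numbers = (((2,), (0,)), ((1,), (2,)))
out = lax.dot_general(lhs, rhs, dimension_numbers)

# The result is laid out as batch dims, then lhs free dims, then rhs free dims.
reference = jnp.einsum("mbk,knb->bmn", lhs, rhs)
np.testing.assert_allclose(out, reference, rtol=1e-6)
print(out.shape)
```

With the stricter check, a call like this would have needed explicit transposes to move the batch dimension to the front of both operands first.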
Since batch dimensions are now allowed in arbitrary positions, it's not hard to
generalize the dot_general batching rule to avoid performing any transposes
(#2972).
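From the user's side, the batching rule comes into play when `jax.vmap` maps over a function that calls dot_general. The sketch below maps over non-leading axes of both operands; the `matmul` helper, the shapes, and the `in_axes` values are hypothetical and chosen only for illustration. Whether a transpose is emitted is an internal detail that this example does not check; it only confirms that the result is correct.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

def matmul(a, b):
    # Hypothetical per-example function: plain (m, k) x (k, n) matrix multiply.
    return lax.dot_general(a, b, (((1,), (0,)), ((), ())))

# Illustrative shapes only: the mapped axis is axis 1 of xs and axis 2 of ys.
xs = jnp.arange(2 * 7 * 3, dtype=jnp.float32).reshape(2, 7, 3)  # (m, batch, k)
ys = jnp.arange(3 * 4 * 7, dtype=jnp.float32).reshape(3, 4, 7)  # (k, n, batch)

batched = jax.vmap(matmul, in_axes=(1, 2))(xs, ys)

# Cross-check against an explicit einsum over the same axes.
expected = jnp.einsum("mbk,knb->bmn", xs, ys)
np.testing.assert_allclose(batched, expected, rtol=1e-6)
print(batched.shape)
```

To see what the batching rule actually produces on a given JAX version, `jax.make_jaxpr(jax.vmap(matmul, in_axes=(1, 2)))(xs, ys)` prints the traced program.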
In passing, also move the bool/int dot expansion into the XLA translation rule. The expansion inside the `lax.dot_general()` wrapper predated the existence of (or at least my knowledge of) `xla.lower_fun()`.
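For readers unfamiliar with the expansion, the idea is to rewrite a dot as an elementwise multiply followed by a sum over the contracting dimension. The following is a rough sketch of that idea for the 2-D integer case only; `matmul_sum_of_products` is a hypothetical name used for illustration and is not the helper used in JAX.

```python
import jax.numpy as jnp
from jax import lax

def matmul_sum_of_products(a, b):
    # Hypothetical illustration: (m, k) x (k, n) -> (m, n) via broadcast
    # multiply and a reduction over the shared k axis.
    return jnp.sum(a[:, :, None] * b[None, :, :], axis=1)

a = jnp.array([[1, 2], [3, 4]], dtype=jnp.int32)
b = jnp.array([[5, 6], [7, 8]], dtype=jnp.int32)

expanded = matmul_sum_of_products(a, b)
direct = lax.dot_general(a, b, (((1,), (0,)), ((), ())))

# Both paths should agree on integer inputs.
assert (expanded == direct).all()
print(expanded)
```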