jax
e2e73a85 - Relax dimension ordering rules for dot_general.

Commit
5 years ago
Relax dimension ordering rules for dot_general. JAX currently requires that batch dimensions appear first and contiguously in the arguments to dot_general. However, XLA does not require this; relax JAX's checks so that it also allows batch dimensions in arbitrary positions. Since batch dimensions are now allowed in arbitrary positions, it's not hard to generalize the dot_general batching rule to avoid performing any transposes (#2972). In passing, also move the bool/int dot expansion into the XLA translation rule. The expansion inside the `lax.dot_general()` wrapper predated the existence of (or at least my knowledge of) `xla.lower_fun()`.
Author
Committer
Parents
Loading