jax
8d304fe7 - Simplify the jaxpr emitted for jnp.repeat in the case that the input dim size is 1.

Commit
258 days ago
Simplify the jaxpr emitted for jnp.repeat in the case that the input dim size is 1. In this case, the repeat() operation is just a broadcast. PiperOrigin-RevId: 780242309
Author
Parents
Loading