jax
f07e963b - Simplify jaxpr for jnp.repeat in scalar repeat case.

Commit
1 year ago
Simplify jaxpr for jnp.repeat in scalar repeat case. Before: ``` In [2]: jax.make_jaxpr(lambda x: jnp.repeat(x, 3, axis=-1))(jnp.arange(12).reshape(3, 4)) Out[2]: { lambda ; a:i32[3,4]. let b:i32[3,4,1] = broadcast_in_dim[broadcast_dimensions=(0, 1) shape=(3, 4, 1)] a c:i32[1,3,1,4,1,1] = reshape[dimensions=None new_sizes=(1, 3, 1, 4, 1, 1)] b d:i32[1,3,1,4,3,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1, 2, 3, 4, 5) shape=(1, 3, 1, 4, 3, 1) ] c e:i32[3,4,3] = reshape[dimensions=None new_sizes=(3, 4, 3)] d f:i32[3,12] = reshape[dimensions=None new_sizes=(3, 12)] e in (f,) } ``` After: ``` In [2]: jax.make_jaxpr(lambda x: jnp.repeat(x, 3, axis=-1))(jnp.arange(12).reshape(3, 4)) Out[2]: { lambda ; a:i32[3,4]. let b:i32[3,4,3] = broadcast_in_dim[broadcast_dimensions=(0, 1) shape=(3, 4, 3)] a c:i32[3,12] = reshape[dimensions=None new_sizes=(3, 12)] b in (c,) } ```
Author
Parents
Loading