Simplify jaxpr for jnp.repeat in scalar repeat case.
Before:
```
In [2]: jax.make_jaxpr(lambda x: jnp.repeat(x, 3, axis=-1))(jnp.arange(12).reshape(3, 4))
Out[2]:
{ lambda ; a:i32[3,4]. let
b:i32[3,4,1] = broadcast_in_dim[broadcast_dimensions=(0, 1) shape=(3, 4, 1)] a
c:i32[1,3,1,4,1,1] = reshape[dimensions=None new_sizes=(1, 3, 1, 4, 1, 1)] b
d:i32[1,3,1,4,3,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1, 2, 3, 4, 5)
shape=(1, 3, 1, 4, 3, 1)
] c
e:i32[3,4,3] = reshape[dimensions=None new_sizes=(3, 4, 3)] d
f:i32[3,12] = reshape[dimensions=None new_sizes=(3, 12)] e
in (f,) }
```
After:
```
In [2]: jax.make_jaxpr(lambda x: jnp.repeat(x, 3, axis=-1))(jnp.arange(12).reshape(3, 4))
Out[2]:
{ lambda ; a:i32[3,4]. let
b:i32[3,4,3] = broadcast_in_dim[broadcast_dimensions=(0, 1) shape=(3, 4, 3)] a
c:i32[3,12] = reshape[dimensions=None new_sizes=(3, 12)] b
in (c,) }
```