avoid some trival operations in mask computations (#3364)
I think #2800 accidentally removed `mul`, `pow`, and `prod` from
mask.py, but those were there to avoid some trivial computations. In
particular, on this program:
```python
@partial(mask, in_shapes=['n'], out_shape='')
def foo(x):
return np.sum(x)
padded_x = np.array([0, 1, 2, 3, 999, 999])
print(make_jaxpr(foo)([padded_x], dict(n=3)))
```
after #2800 (and until this commit) we'd print
```
{ lambda c h ; a b.
let d = mul b 1
e = mul d 1
f = add e 0
g = let c f
i = select g a h
j = reduce_sum [ axes=(0,) ] i
in (j,) }
```
but before #2800 (and after this commit) we print
```
{ lambda c e ; a b.
let d = lt c b
f = select d a e
g = reduce_sum[ axes=(0,) ] f
in (g,) }
```
This might save a tiny bit of work, but also it means the make_jaxpr
results are cleaner, and we like to show those!
@j-towns spotted the fix
cc @JuliusKunze