jax
3b3e7a29 - avoid some trival operations in mask computations (#3364)

Commit
5 years ago
avoid some trival operations in mask computations (#3364) I think #2800 accidentally removed `mul`, `pow`, and `prod` from mask.py, but those were there to avoid some trivial computations. In particular, on this program: ```python @partial(mask, in_shapes=['n'], out_shape='') def foo(x): return np.sum(x) padded_x = np.array([0, 1, 2, 3, 999, 999]) print(make_jaxpr(foo)([padded_x], dict(n=3))) ``` after #2800 (and until this commit) we'd print ``` { lambda c h ; a b. let d = mul b 1 e = mul d 1 f = add e 0 g = let c f i = select g a h j = reduce_sum [ axes=(0,) ] i in (j,) } ``` but before #2800 (and after this commit) we print ``` { lambda c e ; a b. let d = lt c b f = select d a e g = reduce_sum[ axes=(0,) ] f in (g,) } ``` This might save a tiny bit of work, but also it means the make_jaxpr results are cleaner, and we like to show those! @j-towns spotted the fix cc @JuliusKunze
Author
Parents
Loading