flax
f31023b4 - Copybara import of the project:

Commit
4 years ago
Copybara import of the project: -- 0e280bbac078dcffe03a511f3132bc3687d39726 by George Necula <gcnecula@gmail.com>: [masking] Remove references to masking.Poly from the lax.py and lax_numpy.py Previously, in order to increase the coverage of masking we added special cases in lax.py and lax_numpy.py to avoid exceptions in presence of masking.Poly. For example: ``` if not isinstance(d, masking.Poly): if some_check(d): raise ValueError ``` All such conditionals make the code behave potentially different when tracing with masking.Poly than when tracing with concrete shapes, which makes it hard to ensure soundness. Perhaps the most eggregious was: ``` if type(i) is Poly: # dummy index if i is polynomial, doesn't matter for shape inference i = 0 ``` PiperOrigin-RevId: 367985125
Author
Committer
Parents
Loading