jax
ad9b6d4d - implement lazy sublanguage

Commit
6 years ago
implement lazy sublanguage Before this commit, this computation would avoid materializing the iota array at trace time: @jit def f(x): m, n = x.shape return x + np.arange(n) But this one would materialize the iota array at trace time and stage it into the computation as a potentially large array constant: @jit def f(x): m, n = x.shape return x + np.arange(m)[:, None] The difference is that previously operations like broadcasts, transposes, and reshapes that add singleton dimensions (as above) would force otherwise lazy values to be materialized, while after this commit broadcasts, transposes, and reshapes are all lazy operations that only update metadata on their input rather than compiling and executing XLA computations and producing new buffers. Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full). This commit replaces the ad-hoc "lazy device constant" system, which was used to get the simpler behavior in the first example above. Incidentally fixes #1431 See https://github.com/google/jax/pull/1668 for more.
Author
Committer
Parents
Loading