jax
8bcee8d4 - fix a leak where compiled results lived too long

Commit
6 years ago
fix a leak where compiled results lived too long The original repro @levskaya showed us was essentially this OOM: for i in range(40): f = jit(lambda: 1. * np.ones((300, 1024, 1024))) f().block_until_ready() Even though f was being rebound on every iteration, the cache entries corresponding to the previous iterations of the loop were sticking around. Instead, if the user drops all references to a function, we want to clear the corresponding compilation cache entries (since they can never be used). The fix here is to use a two-level cache for compiled code: the first level is a WeakKeyDictionary keyed by the raw Python callable underlying the WrappedFun, and the second level is a regular dictionary keyed by (transforms, params, args). Because this logic is now present in linear_util.py:cache, the implementations of WrappedFun.__eq__ and WrappedFun.__hash__ may be superfluous now. One unintended consequence is that this implementation now avoids using fastcache.crlu_cache for the jit and pmap compilation caches. It was easier to implement this logic in pure Python. We might want to revise this for performance reasons. This commit also incidentally fixed #1600.
Author
Committer
Parents
Loading