fix a leak where compiled results lived too long
The original repro @levskaya showed us was essentially this OOM:
    for i in range(40):
      f = jit(lambda: 1. * np.ones((300, 1024, 1024)))
      f().block_until_ready()
Even though f was being rebound on every iteration, the cache entries
corresponding to the previous iterations of the loop were sticking around.
Instead, if the user drops all references to a function, we want to clear the
corresponding compilation cache entries (since they can never be used).
The fix here is to use a two-level cache for compiled code: the first level is
a WeakKeyDictionary keyed by the raw Python callable underlying the WrappedFun,
and the second level is a regular dictionary keyed by (transforms, params,
args). Because this logic is now present in linear_util.py:cache, the
implementations of WrappedFun.__eq__ and WrappedFun.__hash__ may be superfluous
now.
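
For illustration, here is a minimal sketch of the two-level idea in plain
Python. It is not the code in linear_util.py: the names two_level_cache,
memoized, and compile_fn are hypothetical, and the (transforms, params, args)
key is stood in for by an arbitrary hashable tuple.

```python
import weakref


def two_level_cache(compile_fn):
    """Hypothetical sketch of a two-level cache, not the linear_util.py code."""
    # Level 1: weakly keyed by the raw Python callable, so the whole
    # per-function table disappears once the function is garbage collected.
    caches = weakref.WeakKeyDictionary()

    def memoized(f, *key):
        # Level 2: a regular dict keyed by the remaining hashable arguments
        # (standing in for transforms, params, and args).
        per_function = caches.setdefault(f, {})
        if key not in per_function:
            # The cached value must not hold a strong reference to f,
            # otherwise the weak key would never be collected.
            per_function[key] = compile_fn(f, *key)
        return per_function[key]

    memoized.caches = caches  # exposed only so the sketch can be inspected
    return memoized


def fake_compile(f, *key):
    # Stand-in for an expensive compilation step; returns a plain value.
    return ("compiled", key)


cached_compile = two_level_cache(fake_compile)
for i in range(3):
    f = lambda: 1.0  # f is rebound on every iteration, as in the repro
    cached_compile(f, "float32", (2, 2))
```

Each rebinding of f drops the last reference to the previous lambda, so its
first-level entry (and every compiled result under it) becomes collectable
instead of accumulating across iterations.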
One unintended consequence is that this implementation now avoids using
fastcache.clru_cache for the jit and pmap compilation caches. It was easier to
implement this logic in pure Python. We might want to revisit this for
performance reasons.
This commit also incidentally fixed #1600.