jax
9b12763b - revive the tracer leak checker

Commit
5 years ago
revive the tracer leak checker The tracer leak checker never worked with the jit (or pmap) compilation cache because 1. it relies on Python reference counting (via a weakref mechanism) to check that there are no more references to a MasterTrace (e.g. from any Tracers associated with it) once a trace is finished, but 2. the compilation caches (i.e. linear_util.cache) can include in their cache key the transforms stacked on the corresponding WrappedFun being jitted, and transforms (specifically trace_to_subjaxpr in partial_eval.py) can include MasterTraces. Hence the cache keys included references to the MasterTraces, defeating the leak checking mechanism. This commit just makes an equal copy of any MasterTraces when building the cache key. MasterTraces are already hashable, with equality defined based on just their level and trace type. MasterTraces are only compared by identity in core.full_raise, and then only to determine if a sublift is needed (because a trace is encountering one of its own tracers but from inside at least one additional level of call scoping). That's not an issue for jit because of the caching, but it could be another issue for calls; TODO figure that out. Co-authored-by: James Bradbury <jekbradbury@google.com>
Author
Committer
Parents
Loading