jax
1fcebbaa - fix reference cycle in jaxpr tracing using weakrefs

Commit
6 years ago
fix reference cycle in jaxpr tracing using weakrefs As one step in tracing user code to a jaxpr using the machinery in partial_eval.py, we construct a bipartite graph made of JaxprTracer nodes, corresponding to values in the user code, and recipe nodes ,particularly those corresponding to jaxpr equations, representing primitive operations. (This representation was put in place in #1224, since when primitives only had single outputs we could identify each primitive operation with the JaxprTracer value it produced.) This graph had reference cycles because each equation recipe points to both its input and output tracers (as a jaxpr eqn has both input and output vars) and a tracer must be able to point to the equation recipe that produced it (for us to toposort the graph from in_tracers to out_tracers in tracers_to_jaxpr). Those cycles caused memory leaks. This commit removes the strong reference cycle using weakrefs. In particular, equation recipes only hold weak references to their output tracers. Before this change, we used the core.JaxprEqn struct both to represent equations in jaxprs (where invars and outvars are instances of the core.Var class) and to represent equation recipes (where invars and outvars are instances of the partial_eval.JaxprTracer class). That was a bit lazy. This commit distinguishes the two as separate JaxprEqn and JaxprEqnRecipe structs. Bug find and test code from @trevorcai. Thanks!
Author
Committer
Parents
Loading