88c802af - Delete cotangent references on their last use (#2719)

* Delete cotangent references on their last use

The current implementation of transposition may add a factor of 2x to peak memory usage in real cases, and _potentially an unbounded factor_ in pathological programs. This happens because the cotangents computed by the `backward_pass` are never evicted from the environment until the whole transposition is complete. Other systems (e.g. PyTorch) generally use refcounting or liveness analysis to remove references as soon as they are known to no longer be needed.

A simple example that showcases this issue:
```python
import jax.numpy as np
from jax import vjp

def f(x):
    for i in range(1000):
        x = x * 4
    return x

x = np.ones(4)
vjp(f, x)[1](x)
```
Adding `print(len(ct_env))` at the end of `backward_pass` reveals that the dictionary actually holds a thousand `DeviceArray`s, even though both the forward and backward passes can be computed in constant memory.

Of course this is the pathological example mentioned above, but one can easily see that keeping the cotangents alive for the whole duration of differentiation makes the memory cost approximately `fwd_coefs + all_fwd_intermediates` instead of `fwd_coefs + program_pathwidth`, where:
* `fwd_coefs` is the memory needed to store all constant coefficients of the linearized function;
* `all_fwd_intermediates` is the memory needed to store _all intermediates appearing in the forward program_;
* `program_pathwidth` is the maximum, over all locations in the transposed program, of the memory needed to store the values live at that location.

Note that we usually have `all_fwd_intermediates > fwd_coefs >> program_pathwidth` (where `>>` means the right-hand side is usually significantly smaller).

* Import Set
* Use a list instead of a dict
* Type annotation
* Import List
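
The change amounts to evicting each cotangent from `ct_env` on its last read. Below is a minimal sketch of that idea in plain Python, not JAX's actual `backward_pass`: the `Eqn` record, `toy_backward_pass`, and the unary multiply-by-constant equations are hypothetical simplifications. The sketch relies on the SSA property of the linearized program: every variable is the output of exactly one equation, so its cotangent is read exactly once (when that equation is transposed) and can be popped at that point.
```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Eqn:
    outvar: str   # variable defined by this equation (SSA: defined exactly once)
    invar: str    # single input variable (kept unary for simplicity)
    coef: float   # linear relationship: outvar = coef * invar


def toy_backward_pass(eqns: List[Eqn], outvar: str, out_ct: float) -> Dict[str, float]:
    """Transpose a linear chain, evicting each cotangent on its last read."""
    ct_env: Dict[str, float] = {outvar: out_ct}
    peak = len(ct_env)
    for eqn in reversed(eqns):
        # Last use: eqn.outvar is defined only here, so its cotangent will
        # never be read again -- pop it instead of leaving it in ct_env.
        ct = ct_env.pop(eqn.outvar, 0.0)
        # Transpose of `out = coef * in`: accumulate `coef * ct` into in's cotangent.
        ct_env[eqn.invar] = ct_env.get(eqn.invar, 0.0) + eqn.coef * ct
        peak = max(peak, len(ct_env))
    print("peak number of live cotangents:", peak)  # stays at 1 for this chain
    return ct_env


# The loop `x = x * 4` from the example above, unrolled into 1000 equations.
eqns = [Eqn(outvar=f"x{i + 1}", invar=f"x{i}", coef=4.0) for i in range(1000)]
cts = toy_backward_pass(eqns, outvar="x1000", out_ct=1.0)
# Replacing the `pop` above with a plain lookup makes the peak grow to 1000,
# which is the behaviour this commit removes.
```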