Delete cotangent references on their last use (#2719)
* Delete cotangent references on their last use
The current implementation of transposition may add a factor of 2x to
peak memory usage in real cases and _potentially an unbounded factor_
in pathological programs. This happens because the cotangents computed
by the `backward_pass` are never evicted from the environment until the
whole transposition is complete. Other systems (e.g. PyTorch) generally
use refcounting or liveness analysis to drop references as soon as they
are known to be no longer needed.
A simple example that showcases the issue:
```python
import jax.numpy as np
from jax import vjp

def f(x):
    for i in range(1000):
        x = x * 4
    return x

x = np.ones(4)
vjp(f, x)[1](x)
```
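To make the growth concrete, here is a toy model of the transposition
loop (a sketch only: the explicit equation list, the `ct_env` dict, and
the hard-coded transpose rule are stand-ins, not JAX's actual
`backward_pass`):

```python
# Model of `f` above: equation i reads variable i and writes variable
# i + 1, forming a chain of 1000 multiplications by 4.
eqns = [(i, i + 1) for i in range(1000)]  # (input var, output var)

ct_env = {1000: 1.0}  # cotangent seeded at the final output
for in_var, out_var in reversed(eqns):
    # Transpose rule for multiplication by a constant.
    ct_env[in_var] = ct_env[out_var] * 4.0
    # ct_env[out_var] is dead from this point on, but nothing ever
    # deletes it, so the environment keeps every cotangent alive.

print(len(ct_env))  # 1001: one cotangent per variable in the chain
```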
The real implementation behaves the same way: adding `print(len(ct_env))`
at the end of `backward_pass` reveals that the dictionary holds a
thousand `DeviceArray`s, even though both the forward and the backward
pass can be computed in constant memory. Of course, this is the
pathological example I mentioned above, but one can easily see that
keeping the cotangents alive for the whole duration of differentiation
makes the memory cost approximately `fwd_coefs + all_fwd_intermediates`
instead of `fwd_coefs + program_pathwidth`, where:
* `fwd_coefs` is the amount of memory necessary to store all constant
coefficients of the linearized function;
* `all_fwd_intermediates` is the amount of memory necessary to
store _all intermediates appearing in the forward program_;
* `program_pathwidth` is the maximum, over all locations in the
transposed program, of the memory necessary to store the values live
at that location.
Note that we usually have
`all_fwd_intermediates > fwd_coefs >> program_pathwidth`
(`>>` meaning that the right-hand side is significantly smaller).
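Here is a minimal sketch of the fix on the same toy chain (again, the
names are stand-ins; the real change operates on jaxpr equations inside
`backward_pass`): precompute the equation at which each cotangent is
read for the last time, then delete the reference as soon as that read
has happened.

```python
eqns = [(i, i + 1) for i in range(1000)]  # (input var, output var)

# Liveness analysis: find the equation whose transposition performs the
# last read of each variable's cotangent. In SSA form that is simply the
# equation that defined the variable, so every cotangent here is read
# exactly once.
last_read = {out_var: idx for idx, (_, out_var) in enumerate(eqns)}

ct_env = {1000: 1.0}  # cotangent seeded at the final output
for idx in reversed(range(len(eqns))):
    in_var, out_var = eqns[idx]
    ct_env[in_var] = ct_env[out_var] * 4.0
    # The last read of out_var's cotangent just happened: drop the
    # reference so the backing buffer can be freed right away.
    if last_read[out_var] == idx:
        del ct_env[out_var]

print(len(ct_env))  # 1: memory now tracks program_pathwidth, not the
                    # number of forward intermediates
```

In a general program a value can be referenced by several equations, so
the analysis would record the last such reference in traversal order
rather than assuming a single read.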
* Import Set
* Use a list instead of a dict
* Type annotation
* Import List