pytorch
d798de2b - Checkpoint CUDA Allocator Private Pool State (#94653)

Checkpoint CUDA Allocator Private Pool State (#94653)

Copying note from cuda caching allocator:

```
* Note [Checkpointing PrivatePoolState]
*
* Refer above to Note [Interaction with CUDA graph capture]. Allocations made
* during graph capture are made from a separate private pool. During graph
* capture allocations behave as usual. During graph replay the allocator
* state does not change even as new tensors are created. The private pool
* will not free its blocks to the main caching allocator until cuda graph use
* is finished, to prevent an allocation from eager clobbering the memory of
* a live but unaccounted-for tensor that was created during replay.
*
* `make_graphed_callables`, a series of separate callables chained in
* successive cuda graphs, can share a memory pool because after a cuda graph
* recording the allocations in the shared private pool exactly reflect the
* tensors that are allocated.
*
* We would like to extend callable chaining to support a graphed callable
* tree. In this scenario, we have a tree of callable chains which will be
* captured with cuda graphs. In the diagram below, we have a tree with four
* callables, A, B, C, and D. Suppose we have captured, and subsequently
* replayed, A, B, and C. Then on a new invocation, we replay A and B, but
* would now like to record D. At this point the private pool will not reflect
* any of the live tensors created during graph replay. Allocations made
* during a new recording with the pool could overwrite those live tensors.
*
* In order to record a new graph capture after replaying prior callables in
* the tree, we need the allocator to reflect the state of the live tensors.
* We checkpoint the state of the private pool after each recording, and then
* reapply it when we are starting a new recording chain. Additionally, we
* must free the allocations for any tensors that died between the end of our
* previous graph replaying and our new recording (TODO). All of the allocated
* segments that existed in the checkpointed state must still exist in the
* pool. There may also exist new segments, which we will free (TODO: link
* note [live tensors between iterations] when it exists).
*
*
* ---------------> A ---------------> B ---------------> C
*                                     |
*                                     |
*                                     |
*                                     |
*                                     ---------------> D
```

A few TODOs:
- Add logic for freeing tensors that have died between a last replay and the current new recording.
- Add logic for a free that might be called on a pointer multiple times (because we are manually freeing live tensors).

The two scenarios above have not been exercised in the tests yet.

Differential Revision: [D43999889](https://our.internmc.facebook.com/intern/diff/D43999889)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94653

Approved by: https://github.com/zdevito
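To make the checkpoint/reapply scheme concrete, here is a toy Python model of the mechanism the note describes. All class and method names here are hypothetical illustrations, not the real `CUDACachingAllocator` API: the pool is checkpointed after a recording, and reapplying the checkpoint frees any segments created since, while every checkpointed segment must still exist.

```python
# Toy model of checkpointing a private pool's state (illustrative only;
# PrivatePool, checkpoint, and set_checkpoint_state are hypothetical names,
# not the real CUDACachingAllocator interface).

class PrivatePool:
    """Simplified private memory pool: tracks live allocated segments."""

    def __init__(self):
        self.live = {}    # ptr -> size of each live segment
        self._next = 0    # fake address counter

    def allocate(self, size):
        ptr = self._next
        self._next += 1
        self.live[ptr] = size
        return ptr

    def free(self, ptr):
        del self.live[ptr]

    def checkpoint(self):
        # Snapshot the pool state right after a graph recording.
        return dict(self.live)

    def set_checkpoint_state(self, state):
        # Every segment in the checkpointed state must still exist.
        assert set(state) <= set(self.live), "checkpointed segment was freed"
        # Segments created after the checkpoint are freed.
        for ptr in list(self.live):
            if ptr not in state:
                self.free(ptr)


# Record A then B into the shared pool, checkpointing after each recording.
pool = PrivatePool()
a = pool.allocate(256)            # tensor captured while recording A
ckpt_after_a = pool.checkpoint()
b = pool.allocate(512)            # tensor captured while recording B
ckpt_after_b = pool.checkpoint()

# Later, before recording a new branch D off of B, reapply the checkpoint
# taken after B so the pool reflects exactly the live tensors at that point,
# freeing any extra segment created in between.
extra = pool.allocate(64)
pool.set_checkpoint_state(ckpt_after_b)

assert extra not in pool.live
assert a in pool.live and b in pool.live
```

The tree scenario in the diagram then amounts to: checkpoint after each recording along a chain, and before recording D, reapply the checkpoint taken after B so new allocations cannot clobber tensors that are live from the replayed prefix.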