d30db9a2 - Replace non-reentrant checkpoint with a rewrite that can be nested and contain grad (#90105)

Changes:
- bc-breaking change: The main difference between this and the old non-reentrant impl it replaces is that we clear recomputed tensors on backward immediately upon unpack, even if retain_graph=True. This has the following additional implications:
  - Accessing _saved_tensors multiple times will silently recompute the forward multiple times.
  - Accessing ctx.saved_tensors twice in the same backward will now raise an error.
- To avoid dealing with the potential consequences, early stopping has been hidden behind a global flag that is False by default and can be enabled via a context manager. We can remove this in a follow-up. As a result, some features of nesting do not work by default.

Before land:
- import to check for more bc-breakingness
- implement any workarounds for the bc-breakingness, if we decide on any
- update docs to reflect the new lifetime of recomputed variables
- update docs to mention the early-stop feature

Follow-ups:
- enable early stopping by default
- update docs/tutorial to feature nested use cases

Related docs:
- code comment: https://github.com/pytorch/pytorch/pull/90105/files#diff-9dcd955620b52ce128e18e3567be88edbb238810460d1288a86fabc20e483b30R448
- design doc: https://docs.google.com/document/d/1UDLhTNv6_kvuDTRlsjfj9WdqtNaQNr8ahrvdBIB6914/edit#
- retains_grad <> checkpoint: https://docs.google.com/document/d/1maiGmuFUxysQL0AdYUU88kngAaXh_L0XpDcLDh_5Ors/edit

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90105
Approved by: https://github.com/albanD
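
For context, a minimal sketch of how the rewritten non-reentrant checkpoint described above might be exercised, assuming the use_reentrant=False path of torch.utils.checkpoint.checkpoint and a set_checkpoint_early_stop context manager as the opt-in flag mentioned in the commit message; exact names and defaults may differ across releases.

```python
# Sketch only: assumes checkpoint(..., use_reentrant=False) and the
# set_checkpoint_early_stop context manager referenced by this commit.
import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop


def inner(x):
    return x.sin().cos()


def outer(x):
    # Nested checkpointing: an inner checkpoint call inside an outer one,
    # which the non-reentrant rewrite is meant to support.
    y = checkpoint(inner, x, use_reentrant=False)
    return y.exp()


x = torch.randn(8, requires_grad=True)

# Early stopping is off by default; opt in via the context manager
# (it must be active during the forward pass).
with set_checkpoint_early_stop(True):
    out = checkpoint(outer, x, use_reentrant=False)

out.sum().backward()
print(x.grad.shape)
```

Note that, per the bc-breaking change above, recomputed tensors are cleared immediately upon unpack even with retain_graph=True, so a second backward through the same graph will trigger recomputation rather than reuse the saved recomputed tensors.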