DeepSpeed
42c1e916 - feat(activation_checkpointing): add `non_reentrant_checkpoint` to support inputs require no grad (#4118)

Comment changes are shownComment changes are hidden
Commit
1 year ago
feat(activation_checkpointing): add `non_reentrant_checkpoint` to support inputs require no grad (#4118) * feat: add `non_reentrant_checkpoint` * feat: add missing output postprocess and change the hook to record leaf forward tensor refs * fix: make the multi_grad_hook registered after graph construction * fix: backward compatibility for multi_tensor_hook * fix: nonlocal reference error of deepspeed_saved_tensors * fix: reduce repeating hook registration * test: add test for `activation_checkpointing.checkpointing.non_reentrant_checkpoint` * Pass correct node size for ZeRO++ (#4085) * Pass correct node size * formatting --------- Co-authored-by: Connor Holmes <development@cmikeh2.me> Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com> * add deepspeed chat arxiv report (#4110) * add deepspeed chat arxiv report * add zeroquant v2 and fp * add selective enhencement * add ignore for 'Youn' in spell checker --------- Co-authored-by: yaozhewei <zheweiy@berkeley.edu> Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com> * style: change flake8 detected style missmatch * test: hack to clone the `test_activation_checkpointing` module for reuse and add regression tests * doc: explain the introduction of `non_reentrant_checkpoint` * doc: explain the test of `non_reentrant_checkpoint` --------- Co-authored-by: Connor Holmes <connorholmes@microsoft.com> Co-authored-by: Connor Holmes <development@cmikeh2.me> Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com> Co-authored-by: Conglong Li <conglong.li@gmail.com> Co-authored-by: yaozhewei <zheweiy@berkeley.edu> Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Author
Parents
  • deepspeed/runtime/activation_checkpointing
    • File
      checkpointing.py
  • tests/unit/runtime/activation_checkpointing
    • File
      test_activation_checkpointing_non_reentrant.py