DeepSpeed
6e2899fb - WA for Torch-compile-Z3-act-apt accuracy issue from the Pytorch repo (#5590)

Commit
1 year ago
WA for Torch-compile-Z3-act-apt accuracy issue from the Pytorch repo (#5590)

We have encountered an accuracy issue when running torch.compile + ZeRO-3 + activation checkpointing: some grads get zeroed (the issue does not occur when running without torch.compile). The issue was also reproduced by Umesh Chand from the DS team.

We found that in the PyTorch repo, torch.compile has been specifically disabled for the checkpointing function using the decorator @torch._disable_dynamo() (see the WA in the PyTorch repo: https://github.com/pytorch/pytorch/blob/ec8b254ef49b4a057cf89c2ae64520fb7b423a3e/torch/utils/checkpoint.py#L324). This indicates that there is some issue between torch.compile and checkpointing that is not necessarily DS-related.

Given that the checkpointing function in DeepSpeed is based on the PyTorch one, we propose adopting this WA to ensure correct behavior; it can be removed later if the underlying issue is fixed.

Note: this shouldn't impact non-torch.compile cases.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
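For context, a minimal sketch of how the PyTorch workaround is applied: the checkpoint entry point is decorated with @torch._disable_dynamo() so that Dynamo does not trace into the recomputation logic. The wrapper name and the delegation to torch.utils.checkpoint.checkpoint below are illustrative assumptions, not the exact code changed in this commit.

```python
# Minimal sketch of the workaround, assuming a DeepSpeed-style checkpoint
# wrapper; names here are illustrative, not the exact code in
# deepspeed/runtime/activation_checkpointing/checkpointing.py.
import torch
import torch.utils.checkpoint


@torch._disable_dynamo()  # keep torch.compile (Dynamo) from tracing into checkpointing
def checkpoint(function, *args):
    # Delegate to the eager activation-checkpointing implementation; with the
    # decorator in place, the recompute-on-backward path runs outside of any
    # compiled graph, avoiding the zeroed-gradient issue described above.
    return torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=True)
```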
Parents
  • deepspeed/runtime/activation_checkpointing/checkpointing.py