transformers
fa22b569 - :rotating_light: Fix gradient checkpointing for several models and improve test robustness (#41818)

* Implement gradient checkpointing in GPTBigCode

Support for gradient checkpointing was lost in the major refactoring in PR #38635; this is an attempt to re-add it.

I extended the tests to
- test `use_reentrant=True` and `False`
- make sure `model.train` is called so that gradient checkpointing works; this is a limitation of the tests currently used by GPTBigCode
- make sure that one (the first) gradient checkpointing layer is called
- make sure that the same non-zero grads are present for normal and checkpointing runs - this is something we tripped over before in PEFT due to the possibly incompletely stored runtime environment in the checkpointed forward step, see also peft#2826

Note that the invocation of `GPTBigCodeBlock.forward` has changed:
- `layer_past` is now passed as a keyword argument so that `GradientCheckpointingLayer.__call__` can see and filter this parameter (`use_reentrant=False` fails otherwise)
- `{encoder_}hidden_states` are still passed as positional arguments so that `torch.utils.checkpoint.checkpoint` receives them as positional args and computes gradients for them (kwargs would be filtered by `GradientCheckpointingLayer`)

* Improve gradient checkpointing tests

- Check that the non-zero gradients of a reference run are also present in the checkpointing run (see the first sketch after this message)
- Make sure that the forward of at least one gradient checkpointing layer is actually called more than once, as expected during gradient checkpointing backward (see the second sketch after this message)

Currently there are some problems with Bert-derived MultipleChoice models: when dropout is enabled, there are scenarios during gradient checkpointing where `classifier.bias.grad` is None. I don't yet have a good explanation for this; disabling dropout resolves it. I would have understood it if it were dropout on the classification layer, but enabling attention dropout also leads to this behavior.

MoE models have selective sparsity depending on the selected experts; for this reason we only compare gradients on parameters collected during the reference backward run.

* Remove duplicated gradient checkpointing code

* Address review comments

* Make test output consistent

* GradientCheckpointingLayer for xlstm, zamba, zamba2

* GradientCheckpointingLayer for swiftformer

Also drop janus from the ignore list - only the VQVAE case is without gradient checkpointing, and it is doubtful that checkpointing is useful in that case. Training with gradient checkpointing is not tested anyway.

* Make an exception for CLVP

The implementation of GradientCheckpointingLayer is not trivial and may break behavior that was previously expected. Therefore we keep CLVP as-is for now.

* Remove unneeded exceptions

---------

Co-authored-by: nemo <git@ningu.net>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
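
The first check can be illustrated with a minimal, self-contained sketch in plain PyTorch (not the actual transformers test; `Block`, `TinyModel`, and `nonzero_grads` are made-up names for illustration): run a reference backward, record which parameters received non-zero gradients, then repeat with gradient checkpointing enabled and the model in training mode, and assert that the same gradients are present and match. Because the comparison loop only iterates over the reference gradients, it also reflects the MoE handling described above.

```python
# Illustrative sketch only - not the transformers test suite.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return torch.relu(self.lin(hidden_states))


class TinyModel(nn.Module):
    def __init__(self, dim=16, n_layers=3, gradient_checkpointing=False, use_reentrant=False):
        super().__init__()
        self.layers = nn.ModuleList(Block(dim) for _ in range(n_layers))
        self.gradient_checkpointing = gradient_checkpointing
        self.use_reentrant = use_reentrant

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # hidden_states is passed positionally so that checkpoint()
                # tracks it and recomputes the layer during backward
                hidden_states = checkpoint(layer, hidden_states, use_reentrant=self.use_reentrant)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


def nonzero_grads(model, x):
    """Run forward + backward and collect all non-zero parameter gradients."""
    model.zero_grad()
    model(x).sum().backward()
    return {
        name: param.grad.clone()
        for name, param in model.named_parameters()
        if param.grad is not None and param.grad.abs().sum() > 0
    }


torch.manual_seed(0)
reference = TinyModel()
checkpointed = TinyModel(gradient_checkpointing=True, use_reentrant=False)
checkpointed.load_state_dict(reference.state_dict())

# checkpointing only kicks in when the model is in training mode
reference.train()
checkpointed.train()

x = torch.randn(2, 16)
ref_grads = nonzero_grads(reference, x)
ckpt_grads = nonzero_grads(checkpointed, x)

# every non-zero gradient of the reference run must also exist (and match)
# in the checkpointing run
for name, grad in ref_grads.items():
    assert name in ckpt_grads, f"missing gradient for {name} under checkpointing"
    torch.testing.assert_close(grad, ckpt_grads[name])
```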
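
The second check can be sketched the same way (again plain PyTorch with a made-up `CountingBlock`, not the real test): count how often a checkpointed block's forward executes. With checkpointing active it should run at least twice, once in the forward pass and once as the recomputation triggered during backward.

```python
# Illustrative sketch only - not the transformers test suite.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CountingBlock(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.lin = nn.Linear(dim, dim)
        self.calls = 0

    def forward(self, hidden_states):
        self.calls += 1
        return torch.tanh(self.lin(hidden_states))


block = CountingBlock()
block.train()
x = torch.randn(4, 8)

out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()

# without checkpointing, forward would run exactly once; the recomputation
# during backward makes it run a second time
assert block.calls >= 2, block.calls
```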