T5 Gradient Checkpointing (#11353)
* Implement gradient checkpointing for T5Stack
* A bit more robust type checking
* Add `gradient_checkpointing` to T5Config
* Formatting
* Set requires_grad only when training
* None return value will only cause problems when training
* Change the output tuple according to `use_cache`
* Enable gradient checkpointing for the decoder
Squashed commit of the following:
commit 658bdd0bd1215353a8770f558bda2ea69a0ad0c7
Author: Ceshine Lee <shuanck@gmail.com>
Date: Sat Apr 24 14:08:17 2021 +0800
Only set `requires_grad` for gradient checkpointing
commit acaeee6b2e675045fb28ce2176444c1d63e908bd
Author: Ceshine Lee <shuanck@gmail.com>
Date: Sat Apr 24 13:59:35 2021 +0800
Make gradient checkpointing work with the decoder
* Formatting
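
Note on the technique: the sketch below illustrates the general idea behind gradient checkpointing in plain Python. It is not the T5Stack code from this change. The scalar `tanh` layers, the function names, and the `segment` parameter are all hypothetical stand-ins chosen for illustration. The forward pass keeps only the input of each segment, and the backward pass recomputes the missing activations one segment at a time, trading extra compute for a smaller number of stored values.

```python
import math


def layer_forward(x, w):
    # Hypothetical scalar "layer" standing in for a transformer block.
    return math.tanh(w * x)


def layer_backward(x, w, grad_out):
    # Chain rule for y = tanh(w * x): dy/dx = w * (1 - y^2).
    y = math.tanh(w * x)
    return grad_out * w * (1.0 - y * y)


def grad_full(x, weights):
    """Baseline: store every activation, then backpropagate."""
    acts = [x]
    for w in weights:
        acts.append(layer_forward(acts[-1], w))
    g = 1.0
    for i in reversed(range(len(weights))):
        g = layer_backward(acts[i], weights[i], g)
    # Returns output, d(output)/d(x), and the number of stored values.
    return acts[-1], g, len(acts)


def grad_checkpointed(x, weights, segment):
    """Checkpointed: store one value per segment, recompute the rest."""
    checkpoints = {}
    h = x
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = h  # input to the first layer of a segment
        h = layer_forward(h, w)
    out = h

    g = 1.0
    peak = len(checkpoints)  # rough count of values held at once
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(weights))
        # Recompute the inputs of layers start..end-1 from the checkpoint.
        acts = [checkpoints[start]]
        for i in range(start, end - 1):
            acts.append(layer_forward(acts[-1], weights[i]))
        peak = max(peak, len(checkpoints) + len(acts))
        # Backpropagate through this segment only.
        for i in reversed(range(start, end)):
            g = layer_backward(acts[i - start], weights[i], g)
    return out, g, peak


if __name__ == "__main__":
    weights = [0.5, -1.2, 0.8, 1.1, -0.7, 0.9, 1.3, -0.4]
    print(grad_full(0.3, weights))
    print(grad_checkpointed(0.3, weights, segment=3))
```

Both functions should return the same output and gradient. Only the count of stored values differs, which is the memory saving that checkpointing is meant to buy at the cost of a second forward pass per segment.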