pytorch
73c9911f - always realize output regardless of the number of reads (#88046)

always realize output regardless of the number of reads (#88046)

This improves hf_Bert from 1.139x to 1.21x. Currently, low-memory dropout does not work for the nn.Dropout module, and before this change we were recomputing all the dropout masks in a very inefficient kernel. This change instead saves the dropout masks in the dropout kernels where they are first computed.

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88046
Approved by: https://github.com/Chillee
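The trade-off behind this change can be sketched abstractly. A minimal toy model, assuming nothing about TorchInductor's actual internals (the `LazyBuffer` class, `realize_threshold`, and `force_realize` names are hypothetical, purely for illustration): a lazy buffer either recomputes its value on every read or "realizes" (materializes and caches) it once. Gating realization on a read-count heuristic can leave an expensive value, such as a dropout mask, recomputed repeatedly; forcing outputs to realize computes it exactly once where it is first produced.

```python
# Hypothetical sketch, NOT TorchInductor code: models the choice between
# recomputing a lazily-defined value on each read vs. realizing it once.
class LazyBuffer:
    def __init__(self, compute_fn, realize_threshold=2):
        self.compute_fn = compute_fn          # e.g. an expensive mask kernel
        self.realize_threshold = realize_threshold
        self.num_reads = 0
        self.cached = None                    # set once the buffer is realized
        self.compute_count = 0                # how many times the kernel ran

    def _compute(self):
        self.compute_count += 1
        return self.compute_fn()

    def read(self, force_realize=False):
        self.num_reads += 1
        # Heuristic path: only realize after enough reads are observed.
        # Forced path (analogous to "always realize output"): cache on
        # the very first read, regardless of the read count.
        if self.cached is None and (force_realize or
                                    self.num_reads >= self.realize_threshold):
            self.cached = self._compute()
        if self.cached is not None:
            return self.cached
        return self._compute()                # recompute: the wasteful case
```

Reading a heuristic-gated buffer three times runs the kernel twice (one wasted recomputation before the threshold trips), while a force-realized buffer runs it exactly once.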
Author: Natalia Gimelshein