pytorch
768014b3 - Allow disabling cache in autocast (automatic mixed precision) (#63552)

Allow disabling cache in autocast (automatic mixed precision) (#63552)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63552

In this PR, we want to exclude these two cases from the `Autocast` weight cache usage:

- Using `torch.jit.trace` under `Autocast`. As reported in https://github.com/pytorch/pytorch/issues/50231 and several other discussions, tracing with `torch.jit.trace` under `Autocast` hits Autocast's weight cache and fails, so the weight cache should be disabled during tracing.
- Using `Autocast` with grad mode enabled.
  - Grad mode is usually used for training, and in the training phase the weights change on every step, so there is no need to cache them.
  - In the recommended `Autocast` training recipe from the [doc](https://pytorch.org/docs/stable/amp.html), `Autocast` clears the cache on every step when leaving the context. Disabling the cache saves those clear operations.

```
import torch.optim as optim
from torch.cuda.amp import autocast

model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
```

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D30644913

Pulled By: ezyang

fbshipit-source-id: ad7bc87372e554e7aa1aa0795e9676871b3974e7
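For context (not part of the original commit message), here is a minimal usage sketch of the tracing case described above, assuming the cache toggle is exposed as the `cache_enabled` keyword on the Python autocast API; the module, tensor shapes, and names are illustrative only.

```
import torch

# Hypothetical toy module used only for illustration.
class Small(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x)

# Requires a CUDA device.
model = Small().cuda().eval()
example = torch.randn(2, 8, device="cuda")

# Disable the autocast weight cache while tracing, so the trace does not
# pick up cached casted weights (the failure mode reported in #50231).
# `cache_enabled=False` is assumed to be the knob this PR exposes.
with torch.cuda.amp.autocast(cache_enabled=False):
    traced = torch.jit.trace(model, example)
```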