pytorch
799bc645 - [Expanded Weights] fix loss reduction (#80892)

Two changes in here:

1. Changes `call_for_per_sample_grads` to be curried.
   - Old call: `call_for_per_sample_grads(module, batch_size, args, kwargs)`
   - New call: `call_for_per_sample_grads(module, batch_size, loss_reduction=loss_reduction)(args, kwargs)`
2. Adds the ability to specify a loss reduction, to match what is done in Opacus. Opacus has a more complete explanation, but essentially they want the per-sample gradient behavior to match what happens in a for loop over single examples. A mean reduction at the end breaks this, since in a batch it scales all the grad_outputs by 1/batch_size, so we offset that by scaling all the grad_samples by batch_size when the loss_reduction is mean.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80892
Approved by: https://github.com/zou3519
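A minimal usage sketch of the new curried call described above. The import path and the string value for `loss_reduction` are assumptions based on the state of the feature around this commit; per-sample gradients are exposed Opacus-style on each parameter's `grad_sample` attribute.

```python
import torch
import torch.nn as nn
# Assumed import path at the time of this commit (private module).
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

batch_size = 8
module = nn.Linear(10, 4)
x = torch.randn(batch_size, 10)

# New curried call: configure module / batch_size / loss_reduction first,
# then invoke the returned callable with the forward args.
out = call_for_per_sample_grads(module, batch_size, loss_reduction="mean")(x)
out.mean().backward()

# Per-sample gradients land on the parameters, shaped (batch_size, *param.shape).
# With loss_reduction="mean" they are rescaled by batch_size so each one matches
# the gradient you would get from a single-example forward/backward.
print(module.weight.grad_sample.shape)  # expected: torch.Size([8, 4, 10])
```

And a small plain-autograd check of the reasoning in item (2): a mean reduction scales each example's contribution to the gradient by 1/batch_size, which is why the grad_samples are multiplied back by batch_size. This sketch uses no Expanded Weights machinery; names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B = 4
lin = nn.Linear(3, 2)
x = torch.randn(B, 3)

# Reference: per-example gradients from a for loop over single examples,
# the behavior Opacus wants per-sample gradients to match.
per_example = []
for i in range(B):
    lin.zero_grad()
    lin(x[i:i + 1]).sum().backward()
    per_example.append(lin.weight.grad.clone())

# With a mean reduction over the batch, autograd scales each example's
# contribution by 1/B, so the accumulated gradient is the average of the
# loop gradients rather than their sum.
lin.zero_grad()
lin(x).sum(dim=1).mean().backward()
mean_reduced = lin.weight.grad

assert torch.allclose(mean_reduced, torch.stack(per_example).mean(dim=0), atol=1e-6)
# Hence scaling each per-sample gradient by B under loss_reduction="mean"
# recovers the single-example for-loop gradients.
```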
Author: samdow