pytorch
3d74fd48 - [Expanded Weights] add ability to not specify batch size (#80944)

Committed 3 years ago
[Expanded Weights] add ability to not specify batch size (#80944)

Opacus has been asking for the ability to not specify a batch size. Previously a user had to do `call_for_per_sample_grads(module, batch_size)(*args, **kwargs)`. They rightfully pointed out that in the common case, where you're passing a single argument to a module's forward function, specifying the batch size is repetitive. The original reason for requiring it was that when a user passes more than one argument, we might not be able to tell what the batch size is if the arguments' batch dimensions don't match.

This change lets a user omit the batch size (or pass it as None), meaning that `call_for_per_sample_grads(linear_module)(torch.randn(5, 4))` now works and infers a batch size of 5. If there are multiple tensor arguments with different batch sizes, we fail, even if one of the inputs wouldn't have been used by the module, because we can't tell which batch size we should be using.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80944
Approved by: https://github.com/zou3519
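A minimal usage sketch of the behavior described above. The import path is an assumption: Expanded Weights is a private API that has lived under `torch.nn.utils._per_sample_grad` and may differ across versions, and the `grad_sample` attribute name is assumed from the Opacus-style design of Expanded Weights.

```python
import torch
import torch.nn as nn

# Assumed import path for the private Expanded Weights entry point;
# it may have moved between PyTorch releases.
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

module = nn.Linear(4, 3)
batched_input = torch.randn(5, 4)  # batch size 5 along dim 0

# After this PR, batch_size may be omitted (or passed as None);
# it is inferred from the single tensor argument's first dimension.
out = call_for_per_sample_grads(module)(batched_input)
out.sum().backward()

# Per-sample gradients are stored on each parameter's grad_sample
# attribute, with the batch dimension prepended (assumed attribute name).
print(module.weight.grad_sample.shape)  # expected: torch.Size([5, 3, 4])
```

With multiple tensor arguments whose first dimensions disagree, the call is expected to raise rather than guess, per the commit message.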
Author: samdow