[Expanded Weights] add ability to not specify batch size (#80944)
Opacus has been asking for the ability to not specify a batch size. Previously, a user had to write
`call_for_per_sample_grads(module, batch_size)(*args, **kwargs)`
They rightfully pointed out that when passing a single argument to a module's forward function, which is the common case, it seems repetitive to specify the batch size. The original reason for requiring it was that when a user passes more than one argument, we might not know what the batch size is if the arguments' leading dimensions don't match. This change lets a user skip the batch size (or pass it as None), so that
`call_for_per_sample_grads(linear_module)(torch.randn(5, 4))`
now works, with an inferred batch size of 5.
If there are multiple tensor arguments with different batch sizes, we fail, even if one of the inputs wouldn't have been used by the module, because we can't tell which batch size we should be using.
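A minimal sketch of the inference rule described above: take the shared leading dimension of all tensor inputs, and fail on a mismatch. This is illustrative only; the helper name and the exact error message are hypothetical, not PyTorch's actual implementation.

```python
def infer_batch_size(*args, **kwargs):
    """Return the shared leading dimension of all tensor-like inputs.

    Raises RuntimeError when tensor inputs disagree on their leading
    dimension, since we cannot tell which one is the batch size.
    (Hypothetical helper; PyTorch's real logic lives in the Expanded
    Weights code.)
    """
    batch_size = None
    for arg in list(args) + list(kwargs.values()):
        if not hasattr(arg, "shape"):  # skip non-tensor arguments
            continue
        leading = arg.shape[0]
        if batch_size is None:
            batch_size = leading
        elif batch_size != leading:
            # Mismatch: fail even if one input is unused by the module.
            raise RuntimeError(
                "Expanded Weights: tensor inputs have different leading "
                f"dimensions ({batch_size} vs {leading}); pass batch_size "
                "explicitly to disambiguate."
            )
    if batch_size is None:
        raise RuntimeError("Expanded Weights: no tensor inputs found")
    return batch_size
```

With a single `(5, 4)` input this returns 5; with inputs of shapes `(5, 4)` and `(3, 4)` it raises, mirroring the behavior described above.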
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80944
Approved by: https://github.com/zou3519