[functorch] max_pool2d_with_indices_backward batch rule for specific case
This gives us batch rule coverage for every op in the cifar10 dp
example. We're slightly slower than opacus, though. The numbers on my
machine are:
- 4 it/s for functorch vmap+grad
- 4.2 it/s for opacus
The gap (roughly 5%) should be investigated, and I am also not sure
whether the benchmarks are comparing the right things.