pytorch
83b132b1 - [pruner] add support for pruning BatchNorm2d (#63519)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63519

If the pruner prunes biases along with weights, and the model has BatchNorm2d layers following pruned Conv2d layers, then the corresponding channels of the BatchNorm2d must also be pruned. Specifically, they need to be zeroed out rather than fully removed, since in eager mode the dimensions between layers must be preserved.

To do this, we add a pruning parametrization called `ZeroesParametrization`, which zeroes out pruned channels rather than removing them. The user must provide, in the config, a tuple of the Conv2d and BatchNorm2d layers that go together. The `prepare` method adds the tuple to `module_groups`, attaches a `PruningParametrization` to the Conv2d layer and a `ZeroesParametrization` to the BatchNorm2d layer, and sets their pruned sets to be the same set. That way, during `step`, both masks are updated with the same pruned indices.

ghstack-source-id: 136562278

Test Plan: `buck test mode/dev-nosan //caffe2/test:ao -- TestBasePruner`

https://pxl.cl/1N1P6

Reviewed By: z-a-f

Differential Revision: D30349855

fbshipit-source-id: 3199d3688d5a70963f9b32d7a8fdac3962ae6a65
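The core mechanism the commit describes (zeroing a BatchNorm channel in place rather than removing it, so layer dimensions are preserved) can be sketched with PyTorch's parametrization API. The class below borrows the name `ZeroesParametrization` from the commit, but it is a simplified illustration of the technique, not the actual PyTorch implementation:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class ZeroesParametrization(nn.Module):
    """Zero out pruned channels instead of removing them, so tensor
    dimensions between layers are preserved in eager mode.

    Illustrative sketch only; the real class lives in torch.ao."""
    def __init__(self, pruned_outputs):
        super().__init__()
        # Set of pruned channel indices. In the commit, this set is shared
        # with the preceding Conv2d's PruningParametrization, so `step`
        # updates both masks with the same indices.
        self.pruned_outputs = pruned_outputs

    def forward(self, x):
        x = x.clone()
        x[list(self.pruned_outputs)] = 0.0  # zero, don't remove
        return x

# Zero channels 1 and 3 of a BatchNorm2d's affine parameters.
pruned = {1, 3}
bn = nn.BatchNorm2d(4)
parametrize.register_parametrization(bn, "weight", ZeroesParametrization(pruned))
parametrize.register_parametrization(bn, "bias", ZeroesParametrization(pruned))
print(bn.weight)  # channels 1 and 3 read as 0; shape stays (4,)
```

Because the parametrization is applied on read, the underlying parameter keeps its full size; only the view seen by the forward pass has the pruned channels zeroed.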