add post unroll optimizations (#36828)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36828
This changes ir complexity for the following:
```
("Name", "Ifs/Loops", "non-tensor ops")
Before: ('max_unpool1d', 0, 12)
After: ('max_unpool1d', 0, 3)
Before: ('max_unpool2d', 0, 22)
After: ('max_unpool2d', 0, 3)
Before: ('max_unpool3d', 0, 33)
After: ('max_unpool3d', 0, 4)
Before: ('adaptive_max_pool2d', 0, 6)
After: ('adaptive_max_pool2d', 0, 3)
Before: ('adaptive_max_pool3d', 0, 9)
After: ('adaptive_max_pool3d', 0, 4)
Before: ('adaptive_avg_pool2d', 0, 6)
After: ('adaptive_avg_pool2d', 0, 3)
Before: ('adaptive_avg_pool3d', 0, 9)
After: ('adaptive_avg_pool3d', 0, 4)
Before: ('instance_norm', 1, 6)
After: ('instance_norm', 0, 0)
Before: ('group_norm', 1, 6)
After: ('group_norm', 0, 0)
Before: ('upsample', 13, 71)
After: ('upsample', 13, 68)
Before: ('upsample', 13, 71)
After: ('upsample', 13, 68)
Before: ('interpolate', 14, 71)
After: ('interpolate', 14, 68)
Before: ('interpolate', 13, 70)
After: ('interpolate', 13, 67)
Before: ('interpolate', 14, 71)
After: ('interpolate', 14, 68)
Before: ('interpolate', 14, 71)
After: ('interpolate', 14, 68)
Before: ('interpolate', 13, 70)
After: ('interpolate', 13, 67)
Before: ('interpolate', 14, 71)
After: ('interpolate', 14, 68)
Before: ('interpolate', 14, 71)
After: ('interpolate', 14, 68)
Before: ('interpolate', 13, 70)
After: ('interpolate', 13, 67)
Before: ('interpolate', 14, 71)
After: ('interpolate', 14, 68)
Before: ('interpolate', 14, 71)
After: ('interpolate', 14, 68)
Before: ('interpolate', 13, 70)
After: ('interpolate', 13, 67)
Before: ('interpolate', 14, 71)
After: ('interpolate', 14, 68)
Before: ('interpolate', 14, 60)
After: ('interpolate', 14, 59)
Before: ('interpolate', 13, 58)
After: ('interpolate', 13, 57)
Before: ('interpolate', 14, 60)
After: ('interpolate', 14, 59)
Before: ('interpolate', 14, 60)
After: ('interpolate', 14, 59)
Before: ('interpolate', 13, 58)
After: ('interpolate', 13, 57)
Before: ('interpolate', 14, 60)
After: ('interpolate', 14, 59)
Before: ('interpolate', 14, 60)
After: ('interpolate', 14, 59)
Before: ('interpolate', 13, 58)
After: ('interpolate', 13, 57)
Before: ('interpolate', 14, 60)
After: ('interpolate', 14, 59)
Before: ('interpolate', 13, 82)
After: ('interpolate', 13, 77)
Before: ('interpolate', 14, 82)
After: ('interpolate', 14, 77)
Before: ('interpolate', 14, 82)
After: ('interpolate', 14, 77)
Before: ('interpolate', 13, 82)
After: ('interpolate', 13, 77)
Before: ('interpolate', 14, 82)
After: ('interpolate', 14, 77)
Before: ('interpolate', 14, 82)
After: ('interpolate', 14, 77)
Before: ('interpolate', 13, 82)
After: ('interpolate', 13, 77)
Before: ('interpolate', 14, 82)
After: ('interpolate', 14, 77)
Before: ('interpolate', 14, 71)
After: ('interpolate', 14, 68)
Before: ('interpolate', 14, 71)
After: ('interpolate', 14, 68)
Before: ('interpolate', 15, 106)
After: ('interpolate', 15, 103)
Before: ('interpolate', 14, 73)
After: ('interpolate', 14, 70)
Before: ('interpolate', 15, 106)
After: ('interpolate', 15, 103)
Before: ('interpolate', 14, 73)
After: ('interpolate', 14, 70)
Before: ('interpolate', 15, 92)
After: ('interpolate', 15, 91)
Before: ('interpolate', 14, 60)
After: ('interpolate', 14, 59)
Before: ('interpolate', 15, 94)
After: ('interpolate', 15, 93)
Before: ('interpolate', 14, 62)
After: ('interpolate', 14, 61)
Before: ('interpolate', 15, 116)
After: ('interpolate', 15, 111)
Before: ('interpolate', 14, 82)
After: ('interpolate', 14, 77)
Before: ('interpolate', 15, 118)
After: ('interpolate', 15, 113)
Before: ('interpolate', 14, 84)
After: ('interpolate', 14, 79)
Before: ('test_nn_BatchNorm1d_3d_input', 3, 9)
After: ('test_nn_BatchNorm1d_3d_input', 2, 3)
Before: ('test_nn_BatchNorm1d_3d_input_not_affine', 3, 9)
After: ('test_nn_BatchNorm1d_3d_input_not_affine', 2, 3)
Before: ('test_nn_BatchNorm1d_zero_batch', 3, 9)
After: ('test_nn_BatchNorm1d_zero_batch', 2, 3)
Before: ('test_nn_BatchNorm2d', 3, 13)
After: ('test_nn_BatchNorm2d', 2, 3)
Before: ('test_nn_BatchNorm2d_2d_simple_average', 3, 15)
After: ('test_nn_BatchNorm2d_2d_simple_average', 2, 5)
Before: ('test_nn_BatchNorm2d_momentum', 3, 13)
After: ('test_nn_BatchNorm2d_momentum', 2, 3)
Before: ('test_nn_BatchNorm2d_not_affine', 3, 13)
After: ('test_nn_BatchNorm2d_not_affine', 2, 3)
Before: ('test_nn_BatchNorm2d_not_tracking_stats', 1, 10)
After: ('test_nn_BatchNorm2d_not_tracking_stats', 0, 0)
Before: ('test_nn_BatchNorm2d_zero_batch', 3, 13)
After: ('test_nn_BatchNorm2d_zero_batch', 2, 3)
Before: ('test_nn_BatchNorm3d', 3, 17)
After: ('test_nn_BatchNorm3d', 2, 3)
Before: ('test_nn_BatchNorm3d_3d_simple_average', 3, 19)
After: ('test_nn_BatchNorm3d_3d_simple_average', 2, 5)
Before: ('test_nn_BatchNorm3d_momentum', 3, 17)
After: ('test_nn_BatchNorm3d_momentum', 2, 3)
Before: ('test_nn_BatchNorm3d_not_affine', 3, 17)
After: ('test_nn_BatchNorm3d_not_affine', 2, 3)
Before: ('test_nn_BatchNorm3d_not_tracking_stats', 1, 14)
After: ('test_nn_BatchNorm3d_not_tracking_stats', 0, 0)
Before: ('test_nn_BatchNorm3d_zero_batch', 3, 17)
After: ('test_nn_BatchNorm3d_zero_batch', 2, 3)
Before: ('test_nn_InstanceNorm1d', 1, 6)
After: ('test_nn_InstanceNorm1d', 0, 0)
Before: ('test_nn_InstanceNorm1d_tracking_stats', 1, 6)
After: ('test_nn_InstanceNorm1d_tracking_stats', 0, 0)
Before: ('test_nn_InstanceNorm2d', 1, 10)
After: ('test_nn_InstanceNorm2d', 0, 0)
Before: ('test_nn_InstanceNorm2d_tracking_stats', 1, 10)
After: ('test_nn_InstanceNorm2d_tracking_stats', 0, 0)
Before: ('test_nn_InstanceNorm3d', 1, 14)
After: ('test_nn_InstanceNorm3d', 0, 0)
Before: ('test_nn_InstanceNorm3d_tracking_stats', 1, 14)
After: ('test_nn_InstanceNorm3d_tracking_stats', 0, 0)
Before: ('test_nn_GroupNorm_1d_affine', 1, 6)
After: ('test_nn_GroupNorm_1d_affine', 0, 0)
Before: ('test_nn_GroupNorm_1d_no_affine_IN', 1, 6)
After: ('test_nn_GroupNorm_1d_no_affine_IN', 0, 0)
Before: ('test_nn_GroupNorm_1d_no_affine_LN', 1, 6)
After: ('test_nn_GroupNorm_1d_no_affine_LN', 0, 0)
Before: ('test_nn_GroupNorm_2d_affine', 1, 10)
After: ('test_nn_GroupNorm_2d_affine', 0, 0)
Before: ('test_nn_GroupNorm_2d_no_affine_IN', 1, 10)
After: ('test_nn_GroupNorm_2d_no_affine_IN', 0, 0)
Before: ('test_nn_GroupNorm_2d_no_affine_LN', 1, 10)
After: ('test_nn_GroupNorm_2d_no_affine_LN', 0, 0)
Before: ('test_nn_AdaptiveMaxPool2d_single', 0, 6)
After: ('test_nn_AdaptiveMaxPool2d_single', 0, 3)
Before: ('test_nn_AdaptiveMaxPool2d_tuple', 0, 6)
After: ('test_nn_AdaptiveMaxPool2d_tuple', 0, 3)
Before: ('test_nn_AdaptiveMaxPool3d_single', 0, 9)
After: ('test_nn_AdaptiveMaxPool3d_single', 0, 4)
Before: ('test_nn_AdaptiveMaxPool3d_tuple', 0, 9)
After: ('test_nn_AdaptiveMaxPool3d_tuple', 0, 4)
Before: ('test_nn_AdaptiveMaxPool3d_single_nonatomic', 0, 9)
After: ('test_nn_AdaptiveMaxPool3d_single_nonatomic', 0, 4)
Before: ('test_nn_AdaptiveMaxPool3d_tuple_nonatomic', 0, 9)
After: ('test_nn_AdaptiveMaxPool3d_tuple_nonatomic', 0, 4)
Before: ('test_nn_AdaptiveAvgPool2d_single', 0, 6)
After: ('test_nn_AdaptiveAvgPool2d_single', 0, 3)
Before: ('test_nn_AdaptiveAvgPool2d_single_1x1output', 0, 6)
After: ('test_nn_AdaptiveAvgPool2d_single_1x1output', 0, 3)
Before: ('test_nn_AdaptiveAvgPool2d_tuple', 0, 6)
After: ('test_nn_AdaptiveAvgPool2d_tuple', 0, 3)
Before: ('test_nn_AdaptiveAvgPool3d_single', 0, 9)
After: ('test_nn_AdaptiveAvgPool3d_single', 0, 4)
Before: ('test_nn_AdaptiveAvgPool3d_tuple', 0, 9)
After: ('test_nn_AdaptiveAvgPool3d_tuple', 0, 4)
```
Test Plan: Imported from OSS
Differential Revision: D21160759
Pulled By: eellison
fbshipit-source-id: 91ca6ef2269ee364ca354c8d0843847744145d25