[JIT] Specialize AutogradZero: merge AutogradAnyNonZero and Not(AutogradAnyNonZero) checks into one. (#44987)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44987
This PR introduces new `prim::AutogradAllZero` and
`prim::AutogradAllNonZero` ops that perform a batched check over
multiple tensors. The specialize-autogradzero pass now generates one
check for all expected-to-be-undefined tensors, one check for all
expected-to-be-defined tensors, and individual checks for the size
parameters passed to `grad_sum_to_size` (this could probably be cleaned
up further in the future).
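For intuition, here is a minimal Python sketch (not the actual C++ implementation of the pass) of the semantics the two new ops encode: a single batched definedness check replaces a chain of per-tensor `prim::AutogradAnyNonZero` / `aten::__not__` pairs.
```python
from typing import Optional, Sequence
import torch

# Hedged sketch: in the JIT, an "undefined" tensor stands in for an implicit
# zero gradient; here it is modeled as None.
def autograd_all_nonzero(tensors: Sequence[Optional[torch.Tensor]]) -> bool:
    # prim::AutogradAllNonZero: every expected-to-be-defined gradient is defined.
    return all(t is not None for t in tensors)

def autograd_all_zero(tensors: Sequence[Optional[torch.Tensor]]) -> bool:
    # prim::AutogradAllZero: every expected-to-be-undefined gradient is undefined.
    return all(t is None for t in tensors)

# The specialized graph's guard is then roughly:
#   aten::all([<size checks...>,
#              autograd_all_nonzero(defined_group),
#              autograd_all_zero(undefined_group)])
```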
An example of what we generated before this change:
```
%1626 : bool = prim::AutogradAnyNonZero(%0)
%1627 : bool = prim::AutogradAnyNonZero(%2)
%1628 : bool = aten::__not__(%1627)
%1629 : bool = prim::AutogradAnyNonZero(%3)
%1630 : bool = aten::__not__(%1629)
%1631 : bool = prim::AutogradAnyNonZero(%4)
%1632 : bool = aten::__not__(%1631)
%1633 : bool = prim::AutogradAnyNonZero(%5)
%1634 : bool = aten::__not__(%1633)
%1635 : bool = prim::AutogradAnyNonZero(%6)
%1636 : bool = aten::__not__(%1635)
%1637 : bool = prim::AutogradAnyNonZero(%7)
%1638 : bool = aten::__not__(%1637)
%1639 : bool = prim::AutogradAnyNonZero(%8)
%1640 : bool = aten::__not__(%1639)
%1641 : bool = prim::AutogradAnyNonZero(%9)
%1642 : bool = aten::__not__(%1641)
%1643 : bool = prim::AutogradAnyNonZero(%10)
%1644 : bool = aten::__not__(%1643)
%1645 : bool = prim::AutogradAnyNonZero(%11)
%1646 : bool = aten::__not__(%1645)
%1647 : bool = prim::AutogradAnyNonZero(%12)
%1648 : bool = aten::__not__(%1647)
%1649 : bool = prim::AutogradAnyNonZero(%13)
%1650 : bool = aten::__not__(%1649)
%1651 : bool = prim::AutogradAnyNonZero(%14)
%1652 : bool = aten::__not__(%1651)
%1653 : bool = prim::AutogradAnyNonZero(%15)
%1654 : bool = aten::__not__(%1653)
%1655 : bool = prim::AutogradAnyNonZero(%16)
%1656 : bool = aten::__not__(%1655)
%1657 : bool = prim::AutogradAnyNonZero(%17)
%1658 : bool = prim::AutogradAnyNonZero(%18)
%1659 : bool = prim::AutogradAnyNonZero(%19)
%1660 : bool = prim::AutogradAnyNonZero(%20)
%1661 : bool = aten::__is__(%self_size.16, %1625)
%1662 : bool = aten::__is__(%other_size.16, %1625)
%1663 : bool = aten::__is__(%self_size.14, %1625)
%1664 : bool = aten::__is__(%self_size.12, %1625)
%1665 : bool = prim::AutogradAnyNonZero(%ingate.7)
%1666 : bool = prim::AutogradAnyNonZero(%forgetgate.7)
%1667 : bool = prim::AutogradAnyNonZero(%cellgate.7)
%1668 : bool = prim::AutogradAnyNonZero(%30)
%1669 : bool = prim::AutogradAnyNonZero(%31)
%1670 : bool = aten::__is__(%self_size.10, %1625)
%1671 : bool = aten::__is__(%other_size.10, %1625)
%1672 : bool = prim::AutogradAnyNonZero(%34)
%1673 : bool = prim::AutogradAnyNonZero(%35)
%1674 : bool = aten::__is__(%self_size.8, %1625)
%1675 : bool = aten::__is__(%other_size.8, %1625)
%1676 : bool = aten::__is__(%self_size.6, %1625)
%1677 : bool = aten::__is__(%other_size.6, %1625)
%1678 : bool = prim::AutogradAnyNonZero(%outgate.7)
%1679 : bool = prim::AutogradAnyNonZero(%41)
%1680 : bool = prim::AutogradAnyNonZero(%42)
%1681 : bool = prim::AutogradAnyNonZero(%43)
%1682 : bool = aten::__is__(%self_size.4, %1625)
%1683 : bool = aten::__is__(%other_size.4, %1625)
%1684 : bool[] = prim::ListConstruct(%1626, %1628, %1630, %1632, %1634, %1636, %1638, %1640, %1642, %1644, %1646, %1648, %1650, %1652, %1654, %1656, %1657, %1658, %1659, %1660, %1661, %1662, %1663, %1664, %1665, %1666, %1667, %1668, %1669, %1670, %1671, %1672, %1673, %1674, %1675, %1676, %1677, %1678, %1679, %1680, %1681, %1682, %1683)
%1685 : bool = aten::all(%1684)
```
The same example after this change:
```
%1625 : None = prim::Constant()
%1626 : bool = aten::__is__(%self_size.16, %1625)
%1627 : bool = aten::__is__(%other_size.16, %1625)
%1628 : bool = aten::__is__(%self_size.14, %1625)
%1629 : bool = aten::__is__(%self_size.12, %1625)
%1630 : bool = aten::__is__(%self_size.10, %1625)
%1631 : bool = aten::__is__(%other_size.10, %1625)
%1632 : bool = aten::__is__(%self_size.8, %1625)
%1633 : bool = aten::__is__(%other_size.8, %1625)
%1634 : bool = aten::__is__(%self_size.6, %1625)
%1635 : bool = aten::__is__(%other_size.6, %1625)
%1636 : bool = aten::__is__(%self_size.4, %1625)
%1637 : bool = aten::__is__(%other_size.4, %1625)
%1638 : bool = prim::AutogradAllNonZero(%0, %17, %18, %19, %20, %ingate.7, %forgetgate.7, %cellgate.7, %30, %31, %34, %35, %outgate.7, %41, %42, %43)
%1639 : bool = prim::AutogradAllZero(%2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16)
%1640 : bool[] = prim::ListConstruct(%1626, %1627, %1628, %1629, %1630, %1631, %1632, %1633, %1634, %1635, %1636, %1637, %1638, %1639)
%1641 : bool = aten::all(%1640)
```
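As a rough illustration (not part of this PR), one way to look for these guard nodes is to script a small differentiable function, run forward and backward a few times so the profiling executor differentiates and specializes the graph, and then dump the most recently executed optimized graph. `torch.jit.last_executed_optimized_graph` is a debugging helper; whether the backward graph with the new guards shows up in its output depends on the executor configuration and PyTorch version.
```python
import torch

@torch.jit.script
def cell(x, w):
    return torch.tanh(x @ w).sum()

x = torch.randn(4, 4, requires_grad=True)
w = torch.randn(4, 4, requires_grad=True)

# Run several iterations so the profiling executor optimizes the graph and
# creates a differentiable subgraph with a specialized backward.
for _ in range(5):
    out = cell(x, w)
    out.backward()

# Debug helper: prints the most recently executed optimized graph. Look for
# prim::AutogradAllNonZero / prim::AutogradAllZero guard nodes in the backward.
print(torch.jit.last_executed_optimized_graph())
```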
My performance measurements showed some changes, but I don't really
trust them and think they are most likely just noise. Below are
tables with min-aggregation over 10 runs:
FastRNN models:
| name | base time (s) | diff time (s) | % change |
| :--- | ---: | ---: | ---: |
| lstm[aten]:bwd | 30.059927 | 29.834089 | -0.8% |
| lstm[aten]:fwd | 25.673708 | 25.700039 | 0.1% |
| lstm[cudnn]:bwd | 17.866232 | 17.893120 | 0.2% |
| lstm[cudnn]:fwd | 11.418444 | 11.408514 | -0.1% |
| lstm[jit]:bwd | 27.127205 | 27.141029 | 0.1% |
| lstm[jit]:fwd | 17.018047 | 16.975451 | -0.3% |
| lstm[jit_multilayer]:bwd | 27.502396 | 27.365149 | -0.5% |
| lstm[jit_multilayer]:fwd | 16.918591 | 16.917767 | -0.0% |
| lstm[jit_premul]:bwd | 22.281199 | 22.215082 | -0.3% |
| lstm[jit_premul]:fwd | 14.848708 | 14.896231 | 0.3% |
| lstm[jit_premul_bias]:bwd | 20.761206 | 21.170969 | 2.0% |
| lstm[jit_premul_bias]:fwd | 15.013515 | 15.037978 | 0.2% |
| lstm[jit_simple]:bwd | 26.715771 | 26.697786 | -0.1% |
| lstm[jit_simple]:fwd | 16.675898 | 16.545893 | -0.8% |
| lstm[py]:bwd | 56.327065 | 54.731030 | -2.8% |
| lstm[py]:fwd | 39.876324 | 39.230572 | -1.6% |
Torch Hub models:
| name | base time (s) | diff time (s) | % change |
| :--- | ---: | ---: | ---: |
| test_eval[BERT_pytorch-cuda-jit] | 0.111706 | 0.106604 | -4.6% |
| test_eval[LearningToPaint-cuda-jit] | 0.002841 | 0.002801 | -1.4% |
| test_eval[Super_SloMo-cuda-jit] | 0.384869 | 0.384737 | -0.0% |
| test_eval[attension_is_all_you_nee...-cuda-jit] | 0.123857 | 0.123923 | 0.1% |
| test_eval[demucs-cuda-jit] | 0.077270 | 0.076878 | -0.5% |
| test_eval[fastNLP-cuda-jit] | 0.000255 | 0.000249 | -2.3% |
| test_eval[moco-cuda-jit] | 0.426472 | 0.427380 | 0.2% |
| test_eval[pytorch_CycleGAN_and_pix...-cuda-jit] | 0.026483 | 0.026423 | -0.2% |
| test_eval[pytorch_mobilenet_v3-cuda-jit] | 0.036202 | 0.035853 | -1.0% |
| test_eval[pytorch_struct-cuda-jit] | 0.001439 | 0.001495 | 3.9% |
| test_train[BERT_pytorch-cuda-jit] | 0.247236 | 0.247188 | -0.0% |
| test_train[Background_Matting-cuda-jit] | 3.536659 | 3.581864 | 1.3% |
| test_train[LearningToPaint-cuda-jit] | 0.015341 | 0.015331 | -0.1% |
| test_train[Super_SloMo-cuda-jit] | 1.018626 | 1.019098 | 0.0% |
| test_train[attension_is_all_you_nee...-cuda-jit] | 0.446314 | 0.444893 | -0.3% |
| test_train[demucs-cuda-jit] | 0.169647 | 0.169846 | 0.1% |
| test_train[fastNLP-cuda-jit] | 0.001990 | 0.001978 | -0.6% |
| test_train[moco-cuda-jit] | 0.855323 | 0.856974 | 0.2% |
| test_train[pytorch_mobilenet_v3-cuda-jit] | 0.497723 | 0.485416 | -2.5% |
| test_train[pytorch_struct-cuda-jit] | 0.309692 | 0.308792 | -0.3% |
Differential Revision: D23794659
Test Plan: Imported from OSS
Reviewed By: bertmaher
Pulled By: ZolotukhinM
fbshipit-source-id: 859b68868ef839c5c6cbc7021879ee22d3144ea8
Author: Mikhail Zolotukhin