[JIT] Improve BatchMM mutability handling (#65097)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65097
Previously, BatchMM would skip any block containing any mutable
operators. Now it will avoid batching any operation whose inputs or
outputs are ever mutated. Specifically: consider a tree of ADD, T,
and MM nodes rooted at an ADD node. If any input or output to any
node in the tree is ever mutated, then the entire tree will be ignored
by BatchMM.
Test Plan: python test/test_jit.py TestBatchMM
Reviewed By: eellison
Differential Revision: D30973515
Pulled By: davidberard98
fbshipit-source-id: 9d836faa1ef0c9e3fefe0ffc0bd265f275471f48