Build shape expressions and remove outputs that are only used by `aten::size`s (#45080)
Summary:
Currently, TE materializes all intermediate results even if they are only used for computing their shapes. This diff ports the approach the OF (Old Fuser) took to deal with this issue. Namely, given the structure of a fusion group we infer all the sizes outside a fusion group based on fusion group's inputs.
A simple example would be:
```
def test_fuse(a, b):
c = a + b
d = c + b
return d
```
Here we don't need to cache `c` as computing a gradient for `b` in `d = c + b` doesn't need it. We do need to compute sizes for all arguments here in case broadcasts happen.
Without this optimization, TE would need to materialize `c` so we can get its size
```
[DUMP profiling_graph_executor_impl.cpp:499] Optimized Graph:
[DUMP profiling_graph_executor_impl.cpp:499] graph(%a.1 : Tensor,
[DUMP profiling_graph_executor_impl.cpp:499] %b.1 : Tensor):
[DUMP profiling_graph_executor_impl.cpp:499] %11 : Tensor = prim::DifferentiableGraph_0(%b.1, %a.1)
[DUMP profiling_graph_executor_impl.cpp:499] return (%11)
[DUMP profiling_graph_executor_impl.cpp:499] with prim::DifferentiableGraph_0 = graph(%11 : Tensor,
[DUMP profiling_graph_executor_impl.cpp:499] %13 : Tensor):
[DUMP profiling_graph_executor_impl.cpp:499] %59 : int[] = aten::size(%13) # <string>:3:44
[DUMP profiling_graph_executor_impl.cpp:499] %62 : int[] = aten::size(%11) # <string>:3:93
[DUMP profiling_graph_executor_impl.cpp:499] %83 : Double(1:1, requires_grad=0, device=cuda:0), %84 : Double(1:1, requires_grad=0, device=cuda:0), %85 : bool = prim::TypeCheck(%11, %13)
[DUMP profiling_graph_executor_impl.cpp:499] %86 : Tensor, %87 : Tensor = prim::If(%85)
[DUMP profiling_graph_executor_impl.cpp:499] block0():
[DUMP profiling_graph_executor_impl.cpp:499] %d.4 : Double(1:1, requires_grad=0, device=cuda:0), %c.4 : Double(1:1, requires_grad=0, device=cuda:0) = prim::TensorExprGroup_0(%83, %84)
[DUMP profiling_graph_executor_impl.cpp:499] -> (%d.4, %c.4)
[DUMP profiling_graph_executor_impl.cpp:499] block1():
[DUMP profiling_graph_executor_impl.cpp:499] %94 : Function = prim::Constant[name="fallback_function", fallback=1]()
[DUMP profiling_graph_executor_impl.cpp:499] %95 : (Tensor, Tensor) = prim::CallFunction(%94, %11, %13)
[DUMP profiling_graph_executor_impl.cpp:499] %96 : Tensor, %97 : Tensor = prim::TupleUnpack(%95)
[DUMP profiling_graph_executor_impl.cpp:499] -> (%96, %97)
[DUMP profiling_graph_executor_impl.cpp:499] %60 : int[] = aten::size(%87) # <string>:3:55
[DUMP profiling_graph_executor_impl.cpp:499] %61 : int[]? = aten::_size_if_not_equal(%59, %60) # <string>:3:19
[DUMP profiling_graph_executor_impl.cpp:499] %64 : int[]? = aten::_size_if_not_equal(%62, %60) # <string>:3:68
[DUMP profiling_graph_executor_impl.cpp:499] %67 : int[] = aten::size(%86) # <string>:3:55
[DUMP profiling_graph_executor_impl.cpp:499] %68 : int[]? = aten::_size_if_not_equal(%60, %67) # <string>:3:19
[DUMP profiling_graph_executor_impl.cpp:499] %71 : int[]? = aten::_size_if_not_equal(%62, %67) # <string>:3:68
[DUMP profiling_graph_executor_impl.cpp:499] return (%86, %61, %64, %68, %71)
[DUMP profiling_graph_executor_impl.cpp:499] with prim::TensorExprGroup_0 = graph(%1 : Double(1:1, requires_grad=0, device=cuda:0),
[DUMP profiling_graph_executor_impl.cpp:499] %4 : Double(1:1, requires_grad=0, device=cuda:0)):
[DUMP profiling_graph_executor_impl.cpp:499] %5 : int = prim::Constant[value=1]()
[DUMP profiling_graph_executor_impl.cpp:499] %c.3 : Double(1:1, requires_grad=0, device=cuda:0) = aten::add(%4, %1, %5) # /scratch/villedepommes/pytorches/bench/test/test_jit.py:2872:16
[DUMP profiling_graph_executor_impl.cpp:499] %2 : int = prim::Constant[value=1]()
[DUMP profiling_graph_executor_impl.cpp:499] %d.3 : Double(1:1, requires_grad=0, device=cuda:0) = aten::add(%c.3, %1, %2) # /scratch/villedepommes/pytorches/bench/test/test_jit.py:2873:16
[DUMP profiling_graph_executor_impl.cpp:499] return (%d.3, %c.3)
```
With this optimization we use `prim::BroadcastSizes` to compute the size of `c`. No need to materialize it.
```
[DUMP profiling_graph_executor_impl.cpp:499] Optimized Graph:
[DUMP profiling_graph_executor_impl.cpp:499] graph(%a.1 : Tensor,
[DUMP profiling_graph_executor_impl.cpp:499] %b.1 : Tensor):
[DUMP profiling_graph_executor_impl.cpp:499] %11 : Tensor = prim::DifferentiableGraph_0(%b.1, %a.1)
[DUMP profiling_graph_executor_impl.cpp:499] return (%11)
[DUMP profiling_graph_executor_impl.cpp:499] with prim::DifferentiableGraph_0 = graph(%11 : Tensor,
[DUMP profiling_graph_executor_impl.cpp:499] %13 : Tensor):
[DUMP profiling_graph_executor_impl.cpp:499] %59 : int[] = aten::size(%13) # <string>:3:44
[DUMP profiling_graph_executor_impl.cpp:499] %62 : int[] = aten::size(%11) # <string>:3:93
[DUMP profiling_graph_executor_impl.cpp:499] %88 : Double(1:1, requires_grad=0, device=cuda:0), %89 : Double(1:1, requires_grad=0, device=cuda:0), %90 : bool = prim::TypeCheck(%11, %13)
[DUMP profiling_graph_executor_impl.cpp:499] %91 : Tensor = prim::If(%90)
[DUMP profiling_graph_executor_impl.cpp:499] block0():
[DUMP profiling_graph_executor_impl.cpp:499] %d.4 : Double(1:1, requires_grad=0, device=cuda:0) = prim::TensorExprGroup_0(%88, %89)
[DUMP profiling_graph_executor_impl.cpp:499] -> (%d.4)
[DUMP profiling_graph_executor_impl.cpp:499] block1():
[DUMP profiling_graph_executor_impl.cpp:499] %97 : Function = prim::Constant[name="fallback_function", fallback=1]()
[DUMP profiling_graph_executor_impl.cpp:499] %98 : (Tensor) = prim::CallFunction(%97, %11, %13)
[DUMP profiling_graph_executor_impl.cpp:499] %99 : Tensor = prim::TupleUnpack(%98)
[DUMP profiling_graph_executor_impl.cpp:499] -> (%99)
[DUMP profiling_graph_executor_impl.cpp:499] %85 : int[] = aten::size(%91)
[DUMP profiling_graph_executor_impl.cpp:499] %86 : int[] = prim::BroadcastSizes(%59, %62)
[DUMP profiling_graph_executor_impl.cpp:499] %61 : int[]? = aten::_size_if_not_equal(%59, %86) # <string>:3:19
[DUMP profiling_graph_executor_impl.cpp:499] %64 : int[]? = aten::_size_if_not_equal(%62, %86) # <string>:3:68
[DUMP profiling_graph_executor_impl.cpp:499] %68 : int[]? = aten::_size_if_not_equal(%86, %85) # <string>:3:19
[DUMP profiling_graph_executor_impl.cpp:499] %71 : int[]? = aten::_size_if_not_equal(%62, %85) # <string>:3:68
[DUMP profiling_graph_executor_impl.cpp:499] return (%91, %61, %64, %68, %71)
[DUMP profiling_graph_executor_impl.cpp:499] with prim::TensorExprGroup_0 = graph(%1 : Double(1:1, requires_grad=0, device=cuda:0),
[DUMP profiling_graph_executor_impl.cpp:499] %4 : Double(1:1, requires_grad=0, device=cuda:0)):
[DUMP profiling_graph_executor_impl.cpp:499] %5 : int = prim::Constant[value=1]()
[DUMP profiling_graph_executor_impl.cpp:499] %c.3 : Double(1:1, requires_grad=0, device=cuda:0) = aten::add(%4, %1, %5) # /scratch/villedepommes/pytorches/bench/test/test_jit.py:2872:16
[DUMP profiling_graph_executor_impl.cpp:499] %2 : int = prim::Constant[value=1]()
[DUMP profiling_graph_executor_impl.cpp:499] %d.3 : Double(1:1, requires_grad=0, device=cuda:0) = aten::add(%c.3, %1, %2) # /scratch/villedepommes/pytorches/bench/test/test_jit.py:2873:16
[DUMP profiling_graph_executor_impl.cpp:499] return (%d.3)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45080
Reviewed By: bertmaher
Differential Revision: D23856410
Pulled By: Krovatkin
fbshipit-source-id: 2956286eb03a4894a5baa151c35e6092466322b1