pytorch
993628c7 - Build shape expressions and remove outputs that are only used by `aten::size`s (#45080)

Commit
4 years ago
Build shape expressions and remove outputs that are only used by `aten::size`s (#45080) Summary: Currently, TE materializes all intermediate results even if they are only used for computing their shapes. This diff ports the approach that the Old Fuser (OF) took to deal with this issue. Namely, given the structure of a fusion group, we infer all the sizes outside the fusion group based on the fusion group's inputs. A simple example would be: ``` def test_fuse(a, b): c = a + b d = c + b return d ``` Here we don't need to cache `c`, as computing a gradient for `b` in `d = c + b` doesn't need it. We do need to compute sizes for all arguments here in case broadcasts happen. Without this optimization, TE would need to materialize `c` so that we can get its size: ``` [DUMP profiling_graph_executor_impl.cpp:499] Optimized Graph: [DUMP profiling_graph_executor_impl.cpp:499] graph(%a.1 : Tensor, [DUMP profiling_graph_executor_impl.cpp:499] %b.1 : Tensor): [DUMP profiling_graph_executor_impl.cpp:499] %11 : Tensor = prim::DifferentiableGraph_0(%b.1, %a.1) [DUMP profiling_graph_executor_impl.cpp:499] return (%11) [DUMP profiling_graph_executor_impl.cpp:499] with prim::DifferentiableGraph_0 = graph(%11 : Tensor, [DUMP profiling_graph_executor_impl.cpp:499] %13 : Tensor): [DUMP profiling_graph_executor_impl.cpp:499] %59 : int[] = aten::size(%13) # <string>:3:44 [DUMP profiling_graph_executor_impl.cpp:499] %62 : int[] = aten::size(%11) # <string>:3:93 [DUMP profiling_graph_executor_impl.cpp:499] %83 : Double(1:1, requires_grad=0, device=cuda:0), %84 : Double(1:1, requires_grad=0, device=cuda:0), %85 : bool = prim::TypeCheck(%11, %13) [DUMP profiling_graph_executor_impl.cpp:499] %86 : Tensor, %87 : Tensor = prim::If(%85) [DUMP profiling_graph_executor_impl.cpp:499] block0(): [DUMP profiling_graph_executor_impl.cpp:499] %d.4 : Double(1:1, requires_grad=0, device=cuda:0), %c.4 : Double(1:1, requires_grad=0, device=cuda:0) = prim::TensorExprGroup_0(%83, %84) [DUMP profiling_graph_executor_impl.cpp:499] 
-> (%d.4, %c.4) [DUMP profiling_graph_executor_impl.cpp:499] block1(): [DUMP profiling_graph_executor_impl.cpp:499] %94 : Function = prim::Constant[name="fallback_function", fallback=1]() [DUMP profiling_graph_executor_impl.cpp:499] %95 : (Tensor, Tensor) = prim::CallFunction(%94, %11, %13) [DUMP profiling_graph_executor_impl.cpp:499] %96 : Tensor, %97 : Tensor = prim::TupleUnpack(%95) [DUMP profiling_graph_executor_impl.cpp:499] -> (%96, %97) [DUMP profiling_graph_executor_impl.cpp:499] %60 : int[] = aten::size(%87) # <string>:3:55 [DUMP profiling_graph_executor_impl.cpp:499] %61 : int[]? = aten::_size_if_not_equal(%59, %60) # <string>:3:19 [DUMP profiling_graph_executor_impl.cpp:499] %64 : int[]? = aten::_size_if_not_equal(%62, %60) # <string>:3:68 [DUMP profiling_graph_executor_impl.cpp:499] %67 : int[] = aten::size(%86) # <string>:3:55 [DUMP profiling_graph_executor_impl.cpp:499] %68 : int[]? = aten::_size_if_not_equal(%60, %67) # <string>:3:19 [DUMP profiling_graph_executor_impl.cpp:499] %71 : int[]? 
= aten::_size_if_not_equal(%62, %67) # <string>:3:68 [DUMP profiling_graph_executor_impl.cpp:499] return (%86, %61, %64, %68, %71) [DUMP profiling_graph_executor_impl.cpp:499] with prim::TensorExprGroup_0 = graph(%1 : Double(1:1, requires_grad=0, device=cuda:0), [DUMP profiling_graph_executor_impl.cpp:499] %4 : Double(1:1, requires_grad=0, device=cuda:0)): [DUMP profiling_graph_executor_impl.cpp:499] %5 : int = prim::Constant[value=1]() [DUMP profiling_graph_executor_impl.cpp:499] %c.3 : Double(1:1, requires_grad=0, device=cuda:0) = aten::add(%4, %1, %5) # /scratch/villedepommes/pytorches/bench/test/test_jit.py:2872:16 [DUMP profiling_graph_executor_impl.cpp:499] %2 : int = prim::Constant[value=1]() [DUMP profiling_graph_executor_impl.cpp:499] %d.3 : Double(1:1, requires_grad=0, device=cuda:0) = aten::add(%c.3, %1, %2) # /scratch/villedepommes/pytorches/bench/test/test_jit.py:2873:16 [DUMP profiling_graph_executor_impl.cpp:499] return (%d.3, %c.3) ``` With this optimization we use `prim::BroadcastSizes` to compute the size of `c`. No need to materialize it. 
``` [DUMP profiling_graph_executor_impl.cpp:499] Optimized Graph: [DUMP profiling_graph_executor_impl.cpp:499] graph(%a.1 : Tensor, [DUMP profiling_graph_executor_impl.cpp:499] %b.1 : Tensor): [DUMP profiling_graph_executor_impl.cpp:499] %11 : Tensor = prim::DifferentiableGraph_0(%b.1, %a.1) [DUMP profiling_graph_executor_impl.cpp:499] return (%11) [DUMP profiling_graph_executor_impl.cpp:499] with prim::DifferentiableGraph_0 = graph(%11 : Tensor, [DUMP profiling_graph_executor_impl.cpp:499] %13 : Tensor): [DUMP profiling_graph_executor_impl.cpp:499] %59 : int[] = aten::size(%13) # <string>:3:44 [DUMP profiling_graph_executor_impl.cpp:499] %62 : int[] = aten::size(%11) # <string>:3:93 [DUMP profiling_graph_executor_impl.cpp:499] %88 : Double(1:1, requires_grad=0, device=cuda:0), %89 : Double(1:1, requires_grad=0, device=cuda:0), %90 : bool = prim::TypeCheck(%11, %13) [DUMP profiling_graph_executor_impl.cpp:499] %91 : Tensor = prim::If(%90) [DUMP profiling_graph_executor_impl.cpp:499] block0(): [DUMP profiling_graph_executor_impl.cpp:499] %d.4 : Double(1:1, requires_grad=0, device=cuda:0) = prim::TensorExprGroup_0(%88, %89) [DUMP profiling_graph_executor_impl.cpp:499] -> (%d.4) [DUMP profiling_graph_executor_impl.cpp:499] block1(): [DUMP profiling_graph_executor_impl.cpp:499] %97 : Function = prim::Constant[name="fallback_function", fallback=1]() [DUMP profiling_graph_executor_impl.cpp:499] %98 : (Tensor) = prim::CallFunction(%97, %11, %13) [DUMP profiling_graph_executor_impl.cpp:499] %99 : Tensor = prim::TupleUnpack(%98) [DUMP profiling_graph_executor_impl.cpp:499] -> (%99) [DUMP profiling_graph_executor_impl.cpp:499] %85 : int[] = aten::size(%91) [DUMP profiling_graph_executor_impl.cpp:499] %86 : int[] = prim::BroadcastSizes(%59, %62) [DUMP profiling_graph_executor_impl.cpp:499] %61 : int[]? = aten::_size_if_not_equal(%59, %86) # <string>:3:19 [DUMP profiling_graph_executor_impl.cpp:499] %64 : int[]? 
= aten::_size_if_not_equal(%62, %86) # <string>:3:68 [DUMP profiling_graph_executor_impl.cpp:499] %68 : int[]? = aten::_size_if_not_equal(%86, %85) # <string>:3:19 [DUMP profiling_graph_executor_impl.cpp:499] %71 : int[]? = aten::_size_if_not_equal(%62, %85) # <string>:3:68 [DUMP profiling_graph_executor_impl.cpp:499] return (%91, %61, %64, %68, %71) [DUMP profiling_graph_executor_impl.cpp:499] with prim::TensorExprGroup_0 = graph(%1 : Double(1:1, requires_grad=0, device=cuda:0), [DUMP profiling_graph_executor_impl.cpp:499] %4 : Double(1:1, requires_grad=0, device=cuda:0)): [DUMP profiling_graph_executor_impl.cpp:499] %5 : int = prim::Constant[value=1]() [DUMP profiling_graph_executor_impl.cpp:499] %c.3 : Double(1:1, requires_grad=0, device=cuda:0) = aten::add(%4, %1, %5) # /scratch/villedepommes/pytorches/bench/test/test_jit.py:2872:16 [DUMP profiling_graph_executor_impl.cpp:499] %2 : int = prim::Constant[value=1]() [DUMP profiling_graph_executor_impl.cpp:499] %d.3 : Double(1:1, requires_grad=0, device=cuda:0) = aten::add(%c.3, %1, %2) # /scratch/villedepommes/pytorches/bench/test/test_jit.py:2873:16 [DUMP profiling_graph_executor_impl.cpp:499] return (%d.3) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/45080 Reviewed By: bertmaher Differential Revision: D23856410 Pulled By: Krovatkin fbshipit-source-id: 2956286eb03a4894a5baa151c35e6092466322b1
Author
Parents
Loading