[sr][pyper] add fusion broadcast_concat_batch_matmul_batch_gather (#76839)
Summary:
Fuse the pattern broadcast_stack -> transpose -> matmul -> flatten (sometimes followed by a second, extraneous no-op flatten) -> index_select into a single fb::broadcast_concat_batch_matmul_batch_gather op.
I added broadcast support to the fused op, although I haven't seen any examples where the inputs actually require broadcasting, so I did not focus on optimizing that path.
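For reference, a minimal eager-mode sketch of the unfused pattern being replaced. The shapes, the gather indices, and the choice of the stacked tensor's own transpose as the second matmul operand are illustrative assumptions, not taken from the actual model graph:

```python
import torch

def unfused_pattern(inputs, gather_idx):
    # broadcast_stack: stack the (already broadcast) inputs into one batch
    stacked = torch.stack(inputs)                # [N, B, D]
    # transpose the last two dims to set up the batched matmul
    transposed = stacked.transpose(1, 2)         # [N, D, B]
    # batched matmul against the transpose (illustrative operand choice)
    product = torch.matmul(stacked, transposed)  # [N, B, B]
    # flatten everything after the batch dim
    flat = product.flatten(start_dim=1)          # [N, B*B]
    # index_select: gather a subset of the flattened columns
    return torch.index_select(flat, 1, gather_idx)

# Illustrative shapes: two [3, 4] inputs, gathering 3 of the 9 columns
a = torch.randn(3, 4)
b = torch.randn(3, 4)
out = unfused_pattern([a, b], torch.tensor([0, 4, 8]))
print(out.shape)  # torch.Size([2, 3])
```

The fused op collapses these five graph nodes (and their intermediate tensors) into one out-variant node, which is where the savings below come from.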
Test Plan:
Saves ~0.15ms (~22%) from adfinder_story_post_ad_session_exit_model's remote_other net, or ~10% of the model's overall execution time. With this change, this model now runs ~4% faster overall in static runtime than the C2 baseline.
Before:
I0503 09:19:27.760602 2477739 PyTorchPredictorBenchLib.cpp:305] PyTorch run finished. Milliseconds per iter: 0.799885. Iters per second: 1250.18
0.201173 ms. 27.4953%. static_runtime::flatten_copy (2 nodes, out variant)
0.108989 ms. 14.8962%. aten::index_select (1 nodes, out variant)
0.105632 ms. 14.4372%. aten::matmul (1 nodes, out variant)
0.0637207 ms. 8.70904%. fb::broadcast_stack (1 nodes)
0.000658989 ms. 0.0900675%. aten::transpose (1 nodes, native)
After:
I0504 10:59:36.388628 1042000 PyTorchPredictorBenchLib.cpp:305] PyTorch run finished. Milliseconds per iter: 0.512515. Iters per second: 1951.16
0.231542 ms. 46.6407%. fb::broadcast_concat_batch_matmul_batch_gather (1 nodes, out variant)
Differential Revision: D36139538
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76839
Approved by: https://github.com/mikeiovine