pytorch
f5a9c36d - [SR] Eliminate extra permute ops before `aten::sum` (#74481)

Commit
2 years ago
[SR] Eliminate extra permute ops before `aten::sum` (#74481) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74481 This diff fixes an interesting performance issue related to `permute_copy`. We see this pattern frequently: ``` y = torch.permute(x, (0, 2, 1)) z = torch.sum(y, dim=-1) ``` With copy variants off, we get a strided output from `permute`, and we hit this (faster) kernel in `sum`: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L589 But with copy variants on, we get a contiguous output from `permute_copy`, which causes us to hit the slower reduction: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L597 But the permute is actually unnecessary, we can just statically turn the graph into this to ensure that the fast kernel is hit with copy variants on: ``` z = torch.sum(x, dim=1) ``` ghstack-source-id: 152003888 Reviewed By: navahgar Differential Revision: D34992319 fbshipit-source-id: 0baf493708ee2180c899814a954d220d88ba1d4f (cherry picked from commit 797b6beb26325c56012e406e14fe211c0b5d744d)
Author
Mike Iovine
Committer
Parents
Loading