5db9f911 - [pt][group_fusion] fix shape guarding in fusion candidate search (#111174)

[pt][group_fusion] fix shape guarding in fusion candidate search (#111174)

Summary: without the `all` in the fix:

```
node.kwargs.get("beta", 1.0) == 1.0
and node.kwargs.get("alpha", 1.0) == 1.0
and len(input_shape) == 2
and len(weight_shape) == 2
and all(x % 2 == 0 for x in input_shape + weight_shape)
and shape <= MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR  # <----- HERE
for shape in input_shape + weight_shape
```

the highlighted statement evaluates to a generator object, which is always truthy, so the shape guard never rejects a candidate. One consequence is that a shape can be an odd number, which forces gmm to load element by element rather than using vectorized loads. In the VDDv3 torchbench example (posted in the test plan), there is a 37 ms GMM call that swamps any gain from fusion. Overall this change makes the GMM fusion 24% faster.

Differential Revision: D48696572

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111174

Approved by: https://github.com/davidberard98
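The failure mode described above is easy to reproduce in plain Python: a bare generator expression is an object, so using it directly in a boolean context is always truthy, whereas wrapping it in `all()` actually evaluates every comparison. A minimal sketch (the constant value and shapes below are illustrative, not the ones PyTorch uses):

```python
# Illustrative limit, standing in for MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR.
MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096

# A shape whose second dimension far exceeds the limit.
input_shape = [2, 10_000_000]

# Buggy form: the trailing "for ... in ..." turns the comparison into a
# generator expression. The object itself is truthy, so the oversized
# shape passes the guard without any comparison ever running.
buggy_guard = (
    shape <= MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR
    for shape in input_shape
)
print(bool(buggy_guard))  # True: the guard is a no-op

# Fixed form: all() consumes the generator and checks every shape.
fixed_guard = all(
    shape <= MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR
    for shape in input_shape
)
print(fixed_guard)  # False: the oversized shape is rejected
```

Linters such as flake8-bugbear flag this pattern, since a generator expression used as a condition is almost never intentional.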