pytorch
f10aab03 - [sparse] Fix semi-structured sparse shape mismatch bug (#110420)

Commit
1 year ago
[sparse] Fix semi-structured sparse shape mismatch bug (#110420) Summary: Currently, PyTorch incorrectly calculates the size of the returned matrix when we pass a non-contiguous batched (>2d) input to the semi-structured sparse subclass. This is most common in MLP layers, where we have 2 linear layers back to back. This will lead to an error like the following: ``` RuntimeError: shape '[20, 64, 64, 3072]' is invalid for input of size 62914560 ``` Where the size of the sparse matmul result is off because we infer the output shape with the wrong tensor shape. This happens because of a bug where we did not update the subclass tensor shape when doing transpose. For semi-structured sparsity, transposing is a no-op where we just set the boolean flag, but we forgot to also update the tensor shape. Note that this error goes away in inference mode, since we avoid decomposing the aten.linear op and handle shape folding ourselves, which changes the execution path. An alternative way to fix this issue is to set TORCH_FLATTEN_LINEAR_3D=True, which will also fix this error. Test Plan: ``` python test/test_sparse_semi_structured.py -k test_mlp ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/110420 Approved by: https://github.com/alexsamardzic, https://github.com/cpuhrsch
Author
Committer
Parents
Loading