pytorch
8a053e38 - remove special casing for sparse CSR shape comparison

Commit
3 years ago
remove special casing for sparse CSR shape comparison Fixes https://github.com/pytorch/pytorch/pull/74264#discussion_r845780445. The shape check works with or without the extras added in #74264. ```py >>> a = torch.rand(2, 2).to_sparse_csr() >>> b = torch.rand(2, 3).to_sparse_csr() >>> torch.testing.assert_close(a, b) AssertionError: The values for attribute 'shape' do not match: torch.Size([2, 2]) != torch.Size([2, 3]). ``` Tensor comparison is split into two parts: 1. Attribute comparison. 2. Value comparison. https://github.com/pytorch/pytorch/blob/bcf6974c207ac0339bfb8bdfdb0b0ec348f7a22f/torch/testing/_comparison.py#L611-L616 The attribute comparison happens in https://github.com/pytorch/pytorch/blob/bcf6974c207ac0339bfb8bdfdb0b0ec348f7a22f/torch/testing/_comparison.py#L618 The check for the matching shape https://github.com/pytorch/pytorch/blob/bcf6974c207ac0339bfb8bdfdb0b0ec348f7a22f/torch/testing/_comparison.py#L647-L648 is one of the few checks that cannot be disabled through keyword arguments. Thus, there is no need for this check in `_compare_sparse_csr_values` since the comparison will fail before if the shapes mismatch. Pull Request resolved: https://github.com/pytorch/pytorch/pull/75593 Approved by: https://github.com/cpuhrsch
Author
Committer
Parents
Loading