pytorch
126ea1cc - relax type equality constraint for scalars (#57532)

Commit
3 years ago
relax type equality constraint for scalars (#57532) Summary: Currently we require type equality for `torch.testing.assert_(equal|close)`: https://github.com/pytorch/pytorch/blob/3db45bcb915a32ffa60378205fd9d96c45d7113f/torch/testing/_asserts.py#L509-L513 That means `assert_equal(1, 1.0)` will correctly fail. Although the type of a scalar is similiar to a dtype of a tensor, `assert_equal(1, 1.0, check_dtype=False)` will also fail while `assert_equal(torch.as_tensor(1), torch.as_tensor(1.0), check_dtype=False)` will pass. To make the interface more consistent, this PR relaxes the type equality constraint, by disabling it in case both inputs are scalars. Pull Request resolved: https://github.com/pytorch/pytorch/pull/57532 Reviewed By: ngimel Differential Revision: D28242428 Pulled By: mruberry fbshipit-source-id: b643c77f48b64fc2c8a43925120d2b634ec336b5
Author
Parents
Loading