enable support for custom error messages in `torch.testing` (#55890)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55890
Proof-of-concept for https://github.com/pytorch/pytorch/pull/55145#issuecomment-817297273
With this the user is able to pass a custom error message to `assert_(equal|close)` which will be used in case the values mismatch. Optionally, a callable can be passed which will be called with mismatch diagnostics and should return an error message:
```python
def make_msg(a, b, info):
return (
f"Argh, we found {info.total_mismatches} mismatches! "
f"That is {info.mismatch_ratio:.1%}!"
)
torch.testing.assert_equal(torch.tensor(1), torch.tensor(2), msg=make_msg)
```
If you imagine `a` and `b` as the outputs of binary ufuncs, the error message could look like this:
```python
def make_msg(input, torch_output, numpy_output, info):
return (
f"For input {input} torch.binary_op() and np.binary_op() do not match: "
f"{torch_output} != {numpy_output}"
)
torch.testing.assert_equal(
torch.binary_op(input),
numpy.binary_op(input),
msg=lambda a, b, info: make_msg(input, a, b, info),
)
```
This should make it much easier for developers to find out what is actually going wrong.
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D27903842
Pulled By: mruberry
fbshipit-source-id: 4c82e3d969e9a621789018018bec6399724cf388