torch/jit/_trace.py in compare_outputs(original, reference, match_wha… (#84850)
Fixes #83533
### Bug:
```
/opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py in _check_trace(check_inputs, func, traced_func, check_tolerance, strict, force_outplace, is_trace_module, _module_class)
525 traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace")
526 fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "Python function")
--> 527 if compare_outputs(traced_outs, fn_outs, "Python function"):
528 check_outs = run_mod_and_filter_tensor_outputs(
529 check_mod_func, inputs, "repeated trace"
/opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py in compare_outputs(original, reference, match_what)
500 else:
501 torch.testing.assert_close(
--> 502 orig.double(),
503 ref.double(),
504 rtol=check_tolerance,
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
```
### Fix:
```
if orig.is_mps or ref.is_mps:
torch.testing.assert_close(
orig.float(),
ref.float(),
rtol=check_tolerance,
atol=default_tolerances(orig, ref)[1],
equal_nan=True,
)
else:
torch.testing.assert_close(
orig.double(),
ref.double(),
rtol=check_tolerance,
atol=default_tolerances(orig, ref)[1],
equal_nan=True,
)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84850
Approved by: https://github.com/davidberard98