Turn on strict dtype checking for test_torch.py (#33825)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33825
Partially addresses #20376
I do this by overriding assertEqual in classes that opt into
this. This means I have to fix #33821. The fix is a little
unsatisfactory as idiomatic Python 2 super() calls don't work
(since the class is no longer in scope); hopefully this will just
work when we go to Python 3.
General approach taken:
- A lot of dtype mismatches are because we specified tensor constants
that infer to some dtype, but the actual dtype needed is something else.
Those are easy, just annotate the tensor() constructor (often a legacy
Tensor/FloatTensor call) with dtype
- There are a few cases where the promotion rules are nontrivial. Some of them
I just typed out the expected promotion rules manually (based on trial
and error)
- There are some more complex cases; if it gets too hairy I just
set exact_dtype=False and nope the fuck out
I don't have time to do it for all the other classes. But the setup
should work if people just incrementally add the overrides to classes,
and then eventually flip the default.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Differential Revision: D20125791
Pulled By: ezyang
fbshipit-source-id: 389c2d1efbd93172af02f13e38ac5e92fe730c57