Fix test_jit_cuda_archflags on machine with more than one arch (#50405)
Summary:
This fixes the following flaky test on machine with gpus of different arch:
```
_________________________________________________________________________________________________________________ TestCppExtensionJIT.test_jit_cuda_archflags __________________________________________________________________________________________________________________
self = <test_cpp_extensions_jit.TestCppExtensionJIT testMethod=test_jit_cuda_archflags>
unittest.skipIf(not TEST_CUDA, "CUDA not found")
unittest.skipIf(TEST_ROCM, "disabled on rocm")
def test_jit_cuda_archflags(self):
# Test a number of combinations:
# - the default for the machine we're testing on
# - Separators, can be ';' (most common) or ' '
# - Architecture names
# - With/without '+PTX'
capability = torch.cuda.get_device_capability()
# expected values is length-2 tuple: (list of ELF, list of PTX)
# note: there should not be more than one PTX value
archflags = {
'': (['{}{}'.format(capability[0], capability[1])], None),
"Maxwell+Tegra;6.1": (['53', '61'], None),
"Pascal 3.5": (['35', '60', '61'], None),
"Volta": (['70'], ['70']),
}
if int(torch.version.cuda.split('.')[0]) >= 10:
# CUDA 9 only supports compute capability <= 7.2
archflags["7.5+PTX"] = (['75'], ['75'])
archflags["5.0;6.0+PTX;7.0;7.5"] = (['50', '60', '70', '75'], ['60'])
for flags, expected in archflags.items():
> self._run_jit_cuda_archflags(flags, expected)
test_cpp_extensions_jit.py:198:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test_cpp_extensions_jit.py:158: in _run_jit_cuda_archflags
_check_cuobjdump_output(expected[0])
test_cpp_extensions_jit.py:134: in _check_cuobjdump_output
self.assertEqual(actual_arches, expected_arches,
../../.local/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py:1211: in assertEqual
super().assertEqual(len(x), len(y), msg=self._get_assert_msg(msg, debug_msg=debug_msg))
E AssertionError: 2 != 1 : Attempted to compare the lengths of [iterable] types: Expected: 2; Actual: 1.
E Flags: , Actual: ['sm_75', 'sm_86'], Expected: ['sm_86']
E Stderr:
E Output: ELF file 1: cudaext_archflags.1.sm_75.cubin
E ELF file 2: cudaext_archflags.2.sm_86.cubin
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50405
Reviewed By: albanD
Differential Revision: D25920200
Pulled By: mrshenli
fbshipit-source-id: 1042a984142108f954a283407334d39e3ec328ce