Avoid testing device in cdist when called in a "Math" context (#54708)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54708
cdist advertises itself as Math but actually it error checks that the inputs
are CPU/CUDA in cdist_impl, which is invoked from a composite context in some
situations. I worked around this by ensuring that when cdist_impl was called in
this way, we DON'T do the device checks, but the entire function is a little
janky and I filed an issue about it at #54096
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D27338813
Pulled By: ezyang
fbshipit-source-id: 1202b02c58584a33dc32a5270e59e5f0af6398c5