enable warnings on cuda synchronization (#62092)
Summary:
This creates `torch.cuda.set_warn_on_synchronization()` function that would warn or error when synchronizing operation is performed. We could wrap it in a context manager for ease of use, but it would be a lie, because it sets global, and not thread-local state. Since it's intended for debugging, maybe that's ok though.
As all `torch.cuda.*` functions, it's going through CPython, not pybind, so the argument is converted to long before being passed to c10 function. I'll make python argument a python enum class, but without pybind it'll still have to go thourgh long conversion.
For a test script
```
import torch
torch.cuda.set_warn_on_synchronization(1)
x=torch.randn(10, device="cuda")
x.nonzero()
y=torch.randn((), device="cuda")
if y:
print("something")
torch.multinomial(x.abs(), 10, replacement=False)
torch.randperm(20000, device="cuda")
ind = torch.randint(10, (3,), device="cuda")
mask = torch.randint(2, (10,), device="cuda", dtype=torch.bool)
val = torch.randn((), device="cuda")
x[mask]=1.
x[mask] = val
torch.cuda.synchronize()
```
the output is
```
/../playground/sync_warn_test.py:4: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
x.nonzero()
/../playground/sync_warn_test.py:7: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
if y:
something
/../playground/sync_warn_test.py:9: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
torch.multinomial(x.abs(), 10, replacement=False)
/../playground/sync_warn_test.py:15: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
x[mask] = val
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62092
Reviewed By: mruberry
Differential Revision: D29968792
Pulled By: ngimel
fbshipit-source-id: cc6f817212c164727ed99ecf6ab050dc29631b9e
Author
Natalia Gimelshein