[CUDA] Preload dependent DLLs (#23674)
### Description
Changes:
(1) Pass --cuda_version in packaging pipeline to build wheel command
line so that cuda_version can be saved. Note that cuda_version is also
required for generating extra_require for
https://github.com/microsoft/onnxruntime/pull/23659.
(2) Update steup.py and onnxruntime_validation.py to save cuda version
to capi/build_and_package_info.py.
(3) Add a helper function to preload dependent DLLs (MSVC, CUDA, CUDNN)
in `__init__.py`. First we will try to load DLLs from nvidia site
packages, then try load remaining DLLs with default path settings.
```
import onnxruntime
onnxruntime.preload_dlls()
```
To show loaded DLLs, set `verbose=True`. It is also possible to disable
loading some types of DLLs like:
```
onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
```
#### PyTorch and onnxruntime in Windows
When working with pytorch, onnxruntime will reuse the CUDA and cuDNN
DLLs loaded by pytorch as long as CUDA and cuDNN major versions are
compatible. Preload DLLs actually might cause issues (see example 2 and
3 below) in Windows.
Example 1: onnxruntime and torch can work together easily.
```
>>> import torch
>>> import onnxruntime
>>> session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\curand64_10.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_heuristic64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_precompiled64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_adv64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\nvrtc64_120_0.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\nvrtc-builtins64_124.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_runtime_compiled64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_cnn64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\numpy.libs\msvcp140-d64049c6e3865410a7dda6a7e9f0c575.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn64_9.dll
D:\anaconda3\envs\py310\msvcp140.dll
D:\anaconda3\envs\py310\msvcp140_1.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cufftw64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\caffe2_nvrtc.dll
D:\anaconda3\envs\py310\vcruntime140_1.dll
D:\anaconda3\envs\py310\vcruntime140.dll
>>> session.get_providers()
['CUDAExecutionProvider', 'CPUExecutionProvider']
```
Example 2: Use preload_dlls after `import torch` is not necessary.
Unfortunately, it seems that multiple DLLs of same filename are loaded.
They can be used in parallel but not ideal since more memory is used.
```
>>> import torch
>>> import onnxruntime
>>> onnxruntime.preload_dlls(verbose=True)
----List of loaded DLLs----
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cufft\bin\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\curand64_10.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_heuristic64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_precompiled64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_adv64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\nvrtc64_120_0.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\nvrtc-builtins64_124.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_runtime_compiled64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_cnn64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cuda_runtime\bin\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\numpy.libs\msvcp140-d64049c6e3865410a7dda6a7e9f0c575.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn64_9.dll
D:\anaconda3\envs\py310\msvcp140_1.dll
D:\anaconda3\envs\py310\msvcp140.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cufftw64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\caffe2_nvrtc.dll
D:\anaconda3\envs\py310\vcruntime140_1.dll
D:\anaconda3\envs\py310\vcruntime140.dll
```
Example 3: Use preload_dlls before `import torch` might cause torch
import error in Windows. Later we may provide an option to load DLLs
from torch directory to avoid this issue.
```
>>> import onnxruntime
>>> onnxruntime.preload_dlls(verbose=True)
----List of loaded DLLs----
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cufft\bin\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cuda_runtime\bin\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\numpy.libs\msvcp140-d64049c6e3865410a7dda6a7e9f0c575.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn64_9.dll
D:\anaconda3\envs\py310\msvcp140.dll
D:\anaconda3\envs\py310\vcruntime140_1.dll
D:\anaconda3\envs\py310\msvcp140_1.dll
D:\anaconda3\envs\py310\vcruntime140.dll
>>> import torch
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "D:\anaconda3\envs\py310\lib\site-packages\torch\__init__.py", line 137, in <module>
raise err
OSError: [WinError 127] The specified procedure could not be found. Error loading "D:\anaconda3\envs\py310\lib\site-packages\torch\lib\cudnn_adv64_9.dll" or one of its dependencies.
```
#### PyTorch and onnxruntime in Linux
In Linux, since pytorch uses nvidia site packages for CUDA and cuDNN
DLLs. Preload DLLs consistently loads same set of DLLs, and it could
help maintaining.
```
>>> import onnxruntime
>>> onnxruntime.preload_dlls(verbose=True)
----List of loaded DLLs----
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn.so.9
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_graph.so.9
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cufft/lib/libcufft.so.11
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/curand/lib/libcurand.so.10
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cublas/lib/libcublas.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cublas/lib/libcublasLt.so.12
>>> import torch
>>> torch.rand(3, 3).cuda()
tensor([[0.4619, 0.0279, 0.2092],
[0.0416, 0.6782, 0.5889],
[0.9988, 0.9092, 0.7982]], device='cuda:0')
>>> session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
>>> session.get_providers()
['CUDAExecutionProvider', 'CPUExecutionProvider']
```
```
>>> import torch
>>> import onnxruntime
>>> session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
/cuda12.8/targets/x86_64-linux/lib/libnvrtc.so.12.8.61
/cudnn9.7/lib/libcudnn_graph.so.9.7.0
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cublas/lib/libcublasLt.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cublas/lib/libcublas.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/curand/lib/libcurand.so.10
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cufft/lib/libcufft.so.11
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn.so.9
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12
```
Without preloading DLLs, onnxruntime will load CUDA and cuDNN DLLs based
on `LD_LIBRARY_PATH`. Torch will reuse the same DLLs loaded by
onnxruntime:
```
>>> import onnxruntime
>>> session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
/cuda12.8/targets/x86_64-linux/lib/libnvrtc.so.12.8.61
/cuda12.8/targets/x86_64-linux/lib/libcufft.so.11.3.3.41
/cuda12.8/targets/x86_64-linux/lib/libcurand.so.10.3.9.55
/cuda12.8/targets/x86_64-linux/lib/libcublas.so.12.8.3.14
/cuda12.8/targets/x86_64-linux/lib/libcublasLt.so.12.8.3.14
/cudnn9.7/lib/libcudnn_graph.so.9.7.0
/cudnn9.7/lib/libcudnn.so.9.7.0
/cuda12.8/targets/x86_64-linux/lib/libcudart.so.12.8.57
>>> import torch
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
/cuda12.8/targets/x86_64-linux/lib/libnvrtc.so.12.8.61
/cuda12.8/targets/x86_64-linux/lib/libcufft.so.11.3.3.41
/cuda12.8/targets/x86_64-linux/lib/libcurand.so.10.3.9.55
/cuda12.8/targets/x86_64-linux/lib/libcublas.so.12.8.3.14
/cuda12.8/targets/x86_64-linux/lib/libcublasLt.so.12.8.3.14
/cudnn9.7/lib/libcudnn_graph.so.9.7.0
/cudnn9.7/lib/libcudnn.so.9.7.0
/cuda12.8/targets/x86_64-linux/lib/libcudart.so.12.8.57
>>> torch.rand(3, 3).cuda()
tensor([[0.2233, 0.9194, 0.8078],
[0.0906, 0.2884, 0.3655],
[0.6249, 0.2904, 0.4568]], device='cuda:0')
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
/cuda12.8/targets/x86_64-linux/lib/libnvrtc.so.12.8.61
/cuda12.8/targets/x86_64-linux/lib/libcufft.so.11.3.3.41
/cuda12.8/targets/x86_64-linux/lib/libcurand.so.10.3.9.55
/cuda12.8/targets/x86_64-linux/lib/libcublas.so.12.8.3.14
/cuda12.8/targets/x86_64-linux/lib/libcublasLt.so.12.8.3.14
/cudnn9.7/lib/libcudnn_graph.so.9.7.0
/cudnn9.7/lib/libcudnn.so.9.7.0
/cuda12.8/targets/x86_64-linux/lib/libcudart.so.12.8.57
```
### Motivation and Context
In many reported issues of import onnxruntime failure, the root cause is
dependent DLLs missing or not in path. This change will make it easier
to resolve those issues.
This is based on Jian's PR
https://github.com/microsoft/onnxruntime/pull/22506 with extra change to
load msvc dlls.
https://github.com/microsoft/onnxruntime/pull/23659 can be used to
install CUDA/cuDNN dlls to site packages. Example command line after
next official release 1.21:
```
pip install onnxruntime-gpu[cuda,cudnn]
```
If user installed pytorch in Linux, those DLLs are usually installed
together with torch.