[CUDA] Update preload_dlls to coexist with PyTorch (#23744)
### Description
Update preload_dlls:
(1) Add a parameter `directory` to specify the DLL location.
(2) In Windows, skip loading CUDA/cuDNN dlls when torch for cuda 12.x
has been imported.
(3) In Windows, default search order for CUDA/cuDNN dlls: lib directory
of torch for cuda 12.x in Windows; nvidia site packages; default DLL
loading paths. User can use the directory parameter to change search
order. Use empty string will change search order to `nvidia site
packages; default DLL loading paths`. Use a path if user wants to load
DLLs from a specific location.
(4) Do not load cudnn sub DLLs in Linux.
The benefit of such change is that ORT could work seamlessly with
PyTorch in both Linux and Windows. We also provide option for advanced
users to load CUDA/cuDNN from a location specified by them.
### Examples in Windows
By default, preload_dlls will load CUDA and cuDNN DLLs from PyTorch if
it is compatible:
```
>>> import onnxruntime
>>> onnxruntime.preload_dlls()
>>> onnxruntime.print_debug_info()
onnxruntime-gpu version: 1.21.0
CUDA version used in build: 12.6
platform: Windows-10-10.0.22631-SP0
Python package, version and location:
onnxruntime==1.20.1 at c:\users\abcd\.conda\envs\py310\lib\site-packages\onnxruntime
onnxruntime-gpu==1.21.0 at c:\users\abcd\.conda\envs\py310\lib\site-packages\onnxruntime
WARNING: multiple onnxruntime packages are installed to the same location. Please 'pip uninstall` all above packages, then `pip install` only one of them.
torch==2.6.0+cu126 at c:\users\abcd\.conda\envs\py310\lib\site-packages\torch
nvidia-cuda-runtime-cu12==12.8.57 at c:\users\abcd\.conda\envs\py310\lib\site-packages\nvidia
nvidia-cudnn-cu12==9.7.1.26 at c:\users\abcd\.conda\envs\py310\lib\site-packages\nvidia
nvidia-cublas-cu12==12.8.3.14 at c:\users\abcd\.conda\envs\py310\lib\site-packages\nvidia
nvidia-cufft-cu12==11.3.3.41 at c:\users\abcd\.conda\envs\py310\lib\site-packages\nvidia
nvidia-curand-cu12==10.3.7.77 at c:\users\abcd\.conda\envs\py310\lib\site-packages\nvidia
nvidia-cuda-nvrtc-cu12==12.6.85 at c:\users\abcd\.conda\envs\py310\lib\site-packages\nvidia
nvidia-nvjitlink-cu12==12.8.61 at c:\users\abcd\.conda\envs\py310\lib\site-packages\nvidia
Environment variable:
PATH=c:\users\abcd\.conda\envs\py310;c:\users\abcd\.conda\envs\py310\Library\usr\bin;c:\users\abcd\.conda\envs\py310\Library\bin;c:\users\abcd\.conda\envs\py310\Scripts;c:\users\abcd\.conda\envs\py310\bin;C:\ProgramData\anaconda3\condabin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\libnvvp;C:\windows\system32;
List of loaded DLLs:
c:\users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_adv64_9.dll
c:\users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_precompiled64_9.dll
c:\users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cufft64_11.dll
c:\users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cublasLt64_12.dll
c:\users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
c:\users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cublas64_12.dll
c:\users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_heuristic64_9.dll
c:\users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_runtime_compiled64_9.dll
c:\users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_graph64_9.dll
c:\users\abcd\.conda\envs\py310\Lib\site-packages\numpy.libs\msvcp140-263139962577ecda4cd9469ca360a746.dll
c:\users\abcd\.conda\envs\py310\msvcp140.dll
c:\users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudart64_12.dll
c:\users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn64_9.dll
c:\users\abcd\.conda\envs\py310\vcruntime140_1.dll
c:\users\abcd\.conda\envs\py310\msvcp140_1.dll
c:\users\abcd\.conda\envs\py310\vcruntime140.dll
Device information:
{
"gpu": {
"driver_version": "571.96",
"devices": [
{
"memory_total": 8589934592,
"memory_available": 6032777216,
"name": "NVIDIA GeForce GTX 1080"
}
]
},
"cpu": {
"brand": "Intel(R) Core(TM) i9-10900X CPU @ 3.70GHz",
"cores": 10,
"logical_cores": 20,
"hz": "3696000000,0",
"l2_cache": 10485760,
"flags": "3dnow,3dnowprefetch,abm,acpi,adx,aes,apic,avx,avx2,avx512bw,avx512cd,avx512dq,avx512f,avx512vl,avx512vnni,bmi1,bmi2,clflush,clflushopt,clwb,cmov,cx16,cx8,de,dtes64,dts,erms,est,f16c,fma,fpu,fxsr,ht,hypervisor,ia64,invpcid,lahf_lm,mca,mce,mmx,monitor,movbe,mpx,msr,mtrr,osxsave,pae,pat,pbe,pcid,pclmulqdq,pdcm,pge,pni,popcnt,pqe,pqm,pse,pse36,rdrnd,rdseed,sep,serial,smap,smep,ss,sse,sse2,sse4_1,sse4_2,ssse3,tm,tm2,tsc,tscdeadline,vme,x2apic,xsave,xtpr",
"processor": "Intel64 Family 6 Model 85 Stepping 7, GenuineIntel"
},
"memory": {
"total": 68414291968,
"available": 40240791552
}
}
>>> import torch
```
In below example, we set `directory=""`, which prefers nvidia site
package like the following:
```
>>> import onnxruntime
>>> onnxruntime.preload_dlls(directory="")
>>> onnxruntime.print_debug_info()
...
List of loaded DLLs:
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_adv64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_ops64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_engines_precompiled64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cufft\bin\cufft64_11.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublasLt64_12.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublas64_12.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_heuristic64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_engines_runtime_compiled64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_graph64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cuda_runtime\bin\cudart64_12.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\numpy.libs\msvcp140-263139962577ecda4cd9469ca360a746.dll
C:\Users\abcd\.conda\envs\py310\msvcp140.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn64_9.dll
C:\Users\abcd\.conda\envs\py310\msvcp140_1.dll
C:\Users\abcd\.conda\envs\py310\vcruntime140_1.dll
C:\Users\abcd\.conda\envs\py310\vcruntime140.dll
...
```
In below example, we import torch before preload_dlls. In this case, ORT
skips loading CUDA/cuDNN DLLs, and use the DLLs from torch:
```
>>> import onnxruntime
>>> import torch
>>> onnxruntime.preload_dlls()
Skip loading CUDA and cuDNN DLLs since torch is imported.
>>> onnxruntime.print_debug_info()
...
List of loaded DLLs:
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\nvrtc64_120_0.alt.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\curand64_10.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cufft64_11.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_heuristic64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_precompiled64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_adv64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cublasLt64_12.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cublas64_12.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_runtime_compiled64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\nvrtc64_120_0.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\nvrtc-builtins64_126.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_cnn64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn_graph64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudart64_12.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\numpy.libs\msvcp140-263139962577ecda4cd9469ca360a746.dll
C:\Users\abcd\.conda\envs\py310\msvcp140.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cudnn64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\cufftw64_11.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\torch\lib\caffe2_nvrtc.dll
C:\Users\abcd\.conda\envs\py310\msvcp140_1.dll
C:\Users\abcd\.conda\envs\py310\vcruntime140_1.dll
C:\Users\abcd\.conda\envs\py310\vcruntime140.dll
...
```
Last example is to load CUDA and cuDNN separately from different
locations. CUDA location is based on CUDA_PATH environment variable, and
cuDNN path is a relative path points to cudnn in nvidia site package.
```
>>> import onnxruntime
>>> import os
>>> os.environ["CUDA_PATH"]
'C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.8'
>>> onnxruntime.preload_dlls(cuda=True, cudnn=False, directory=os.path.join(os.environ["CUDA_PATH"], "bin"))
>>> onnxruntime.preload_dlls(cuda=False, cudnn=True, directory="..\\nvidia\\cudnn\\bin")
>>> onnxruntime.print_debug_info()
...
List of loaded DLLs:
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_adv64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_ops64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_engines_precompiled64_9.dll
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin\cufft64_11.dll
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin\cublasLt64_12.dll
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin\cublas64_12.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_heuristic64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_engines_runtime_compiled64_9.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_graph64_9.dll
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin\cudart64_12.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\numpy.libs\msvcp140-263139962577ecda4cd9469ca360a746.dll
C:\Users\abcd\.conda\envs\py310\msvcp140.dll
C:\Users\abcd\.conda\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn64_9.dll
C:\Users\abcd\.conda\envs\py310\vcruntime140_1.dll
C:\Users\abcd\.conda\envs\py310\msvcp140_1.dll
C:\Users\abcd\.conda\envs\py310\vcruntime140.dll
...
```
### Motivation and Context
To address issues mentioned in description of
https://github.com/microsoft/onnxruntime/pull/23674 that onnxruntime
preload might cause conflicts with PyTorch.
Before this change, `import torch` after `onnxruntime.preload_dlls` will
cause issue:
```
>>> import onnxruntime
>>> onnxruntime.preload_dlls(verbose=True)
----List of loaded DLLs----
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cufft\bin\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cuda_runtime\bin\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\numpy.libs\msvcp140-d64049c6e3865410a7dda6a7e9f0c575.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn64_9.dll
D:\anaconda3\envs\py310\msvcp140.dll
D:\anaconda3\envs\py310\vcruntime140_1.dll
D:\anaconda3\envs\py310\msvcp140_1.dll
D:\anaconda3\envs\py310\vcruntime140.dll
>>> import torch
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "D:\anaconda3\envs\py310\lib\site-packages\torch\__init__.py", line 137, in <module>
raise err
OSError: [WinError 127] The specified procedure could not be found. Error loading "D:\anaconda3\envs\py310\lib\site-packages\torch\lib\cudnn_adv64_9.dll" or one of its dependencies.
```