onnxruntime commit e6023b0c: Create a CUDA based memory arena instead of Cuda Allocator wrapped into BFCArena (#26535)

### Description

This change allows users to better control GPU memory in shared environments with multiple tenants or multiple inference sessions per process. The CUDA-based memory pool features native allocations on streams, can trim memory on `Shrink()` if enabled, and releases memory back to the system according to user-specified parameters. In my limited testing, latencies were comparable with running on BFCArena, although your mileage and requirements may vary.

`CudaMemoryPoolArena` is enabled via `OrtArenaCfg`, which gains three new parameters (see the configuration sketch at the end of this description):

- `use_cuda_mempool`: set to 1 to enable.
- `cuda_mempool_release_threshold`: the amount of memory to keep cached.
- `cuda_mempool_bytes_to_keep_on_shrink`: the amount of memory to keep when the pool is trimmed on `Shrink()`; allocated memory is not affected.

### Motivation and Context

Better GPU memory control in multi-tenant environments.

This PR also introduces new options for `onnxruntime_perf_test` to help clients figure out the best settings for their case:

- `--enable_cuda_mempool 209715200;1048576`: the first parameter is `cuda_mempool_release_threshold`; the second, `cuda_mempool_bytes_to_keep_on_shrink`, can be zero if shrink is not enabled.
- `--shrink_arena_between_runs gpu:0`: measure performance and memory consumption with shrink enabled (a `RunOptions` sketch is also included below).

Strictly speaking, this new allocator does not need `Shrink()`, since the CUDA mempool may release memory on the go according to `cuda_mempool_release_threshold`.

Here are some performance numbers gathered while running the HF_Bart model. If the CUDA mempool release threshold is set too low, latency increases because the system ends up constantly allocating and releasing memory. As we raise the threshold and allow more memory to stay allocated, latency improves, and we end up keeping only about half as much memory between runs compared to BFCArena.

Default setup with BFCArena:

> onnxruntime_perf_test -s -e cuda -I -S 10 -m times -r 100 "hf_Bart_torchscript.onnx"

Average inference time cost total: 66.493545 ms
P99 Latency: 0.0805385 s
Total memory allocated: 1,409,286,144 bytes

200 MB release threshold:

> onnxruntime_perf_test -s -e cuda --enable_cuda_mempool 209715200;0 -I -S 10 -m times -r 100 hf_Bart_torchscript.onnx

Average inference time cost total: 77.367473 ms
P99 Latency: 0.0931895 s

0.5 GB release threshold:

> onnxruntime_perf_test -s -e cuda --enable_cuda_mempool 536870912;0 -I -S 10 -m times -r 100 hf_Bart_torchscript.onnx

Average inference time cost total: 75.112840 ms
P99 Latency: 0.0910992 s

1 GB release threshold:

> onnxruntime_perf_test -s -e cuda --enable_cuda_mempool 1073741824;0 -I -S 10 -m times -r 100 hf_Bart_torchscript.onnx

Average inference time cost total: 66.533892 ms
P99 Latency: 0.0761336 s

Enabling shrink shows we retain only about half the memory compared to BFCArena between inference runs:

> CudaMempoolArena::Shrink: pool current_in_use: 709,603,688 reserved size after trim: 738,197,504 bytes.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
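For illustration, here is a minimal C++ sketch of enabling the new arena, assuming the three parameter names above are accepted as key strings by the existing `CreateArenaCfgV2` C API and that the resulting config is attached through the `default_memory_arena_cfg` field of `OrtCUDAProviderOptions`. The exact key strings and the model path are assumptions drawn from this description, not a verified reference.

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  const OrtApi& api = Ort::GetApi();

  // Assumed key strings, taken from the parameter names in this PR
  // description; CreateArenaCfgV2 takes parallel key/value arrays.
  const char* keys[] = {"use_cuda_mempool",
                        "cuda_mempool_release_threshold",
                        "cuda_mempool_bytes_to_keep_on_shrink"};
  const size_t values[] = {1,          // enable the CUDA mempool arena
                           209715200,  // keep up to 200 MB cached
                           0};         // release everything on Shrink()

  OrtArenaCfg* arena_cfg = nullptr;
  Ort::ThrowOnError(api.CreateArenaCfgV2(keys, values, 3, &arena_cfg));

  // Attach the arena config to the CUDA execution provider.
  OrtCUDAProviderOptions cuda_options{};
  cuda_options.device_id = 0;
  cuda_options.default_memory_arena_cfg = arena_cfg;

  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "cuda_mempool_demo");
  Ort::SessionOptions session_options;
  session_options.AppendExecutionProvider_CUDA(cuda_options);

  // Hypothetical model path for illustration.
  Ort::Session session(env, ORT_TSTR("hf_Bart_torchscript.onnx"), session_options);
  // ... run inference as usual; release arena_cfg with
  // api.ReleaseArenaCfg(arena_cfg) once the session is destroyed ...
  return 0;
}
```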
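Continuing the sketch above, shrink can be requested per run through the existing `memory.enable_memory_arena_shrinkage` run-option key, which is presumably what `--shrink_arena_between_runs gpu:0` maps to in `onnxruntime_perf_test`:

```cpp
// Ask the arena on device gpu:0 to shrink after the Run() call completes.
Ort::RunOptions run_options;
run_options.AddConfigEntry("memory.enable_memory_arena_shrinkage", "gpu:0");
// session.Run(run_options, input_names, &input_tensor, 1, output_names, 1);
// The mempool is then trimmed toward cuda_mempool_bytes_to_keep_on_shrink;
// memory still allocated is unaffected.
```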