onnxruntime
9d7e6d53 - [CPU/CUDA ep] Improve DeformConv op performance (#27824)

### Description

Improve DeformConv op performance.

### Motivation and Context

This PR consolidates a series of optimizations targeting the `DeformConv` (Deformable Convolution) operator across both the CPU and CUDA execution providers.

* **For CPU:** The previous implementation suffered from bottlenecks due to redundant computation, lack of vectorization in bilinear sampling, and sub-optimal thread pool utilization. This overhaul redesigns the memory layout and execution pipeline to maximize SIMD opportunities and harden memory safety.
* **For GPU:** The batched GEMM operation previously relied on an intermediate buffer and a custom scatter kernel to format the output, which consumed extra memory and kernel launch overhead. This update introduces a zero-copy approach.

---

#### 1. CPU Optimizations & Refactoring

The CPU execution path has been heavily refactored to minimize branching in hot paths, maximize vectorization, and safely handle edge cases.

| Feature / Optimization | Description | Key Benefit |
| :--- | :--- | :--- |
| **AoSoA Bilinear Sampling Plan** | Replaced on-the-fly interpolation with a precomputed sampling plan using an 8-lane Array-of-Structures-of-Arrays (AoSoA) layout (`kPlanAoSoALanes`). | Aligns exactly with 256-bit AVX2 vectors, enabling highly efficient SIMD unrolling during the `im2col` gathering phase. |
| **Kernel Metadata Caching** | Introduced `DeformConvKernelMetaCacheData` to cache static convolution geometry (e.g., `kH`, `kW`, `padding`, `dilation`). | Eliminates the O(kernel_size) overhead of reallocating and recomputing base offsets on every `Compute()` call. |
| **Fast Math & Branchless Logic** | Implemented a custom `DeformConvFastFloor` and an inverted bounds check with bitwise operations to evaluate all four corners simultaneously. | Removes expensive `std::floor` calls and unpredictable branches from the operator's hottest path. |
| **Enhanced Parallelization** | Flattened the bilinear sampling plan build tasks across spatial pixels. | Allows `concurrency::ThreadPool::TryParallelFor` to split fine-grained work effectively, drastically improving thread pool scaling. |
| **Hardened Bounds Checking** | Introduced compute-time bounds checks using `CheckedMulSizeT` and `CheckedBatchSpan`. | Ensures batch indexing and stride calculations stay within the addressable `size_t` range, preventing integer overflow vulnerabilities. |
| **Bias Addition Refactoring** | Refactored bias addition to avoid expensive `div`/`mod` operations, applying `ORT_CPU_RESTRICT` and force-inlining. | Maximizes memory throughput and instruction pipelining during the final bias addition phase. |

---

#### 2. GPU (CUDA) Optimizations

The CUDA implementation was optimized to reduce memory footprint and eliminate unnecessary kernel launches.

* **Zero-Copy GEMM Output:** Removed the temporary `gemm_output_buffer` allocation entirely. By carefully configuring the `stride_c` parameter (`stride_c_y = M * output_image_size`), `cublasGemmStridedBatchedHelper` now writes the computed output directly into the correct NCHW memory layout of the final `Y` tensor.
* **Kernel Elimination:** Completely removed the `DeformConvCopyGemmOutputRowMajorToNCHW` custom kernel and its associated dispatch logic. This reduces kernel launch overhead, lowers GPU memory bandwidth pressure, and simplifies the overall CUDA execution pipeline.
* **Reduced Memory Footprint:** Updated the `bytes_per_image` calculation for workspace memory to reflect the removal of the GEMM output buffer. This allows the operator to process more images in parallel under the same memory constraints.

---

#### 3. Changed

- **Batch chunking:** The chunk size `k` is chosen so that the number of outer rounds is minimized under the temp-memory cap; **`k` does not have to divide `N`**. The host loop uses `cur_parallel = min(k, N - b)`, so the last chunk may be smaller. This is the intended default behavior for this EP (not yet in a formal release).
- **Kernel-size templates:** Im2col is specialized for **1×1, 3×3, and 7×7**; other sizes (including **5×5**) use the **dynamic** `kH`/`kW` path. Rationale: 5×5 is less common in current stacks (often replaced by stacked 3×3), while specializing 7×7 targets common large-kernel cases. Older DCN/detection models that still use **5×5** deformable conv take the dynamic path; correctness is unchanged, only compile-time unrolling differs.
- **Aliasing flags:** Updated the DeformConv aliasing comments to make the stronger guarantee explicit: if output `Y` overlaps any input buffer, results can be incorrect regardless of `restrict`, because output writes may clobber source elements before they are fully consumed. `restrict` further tightens this by introducing undefined behavior when aliasing assumptions are violated.

---

### Summary

In the current implementation, CPU performance is 33x that of TorchVision (the main branch is 15x). If we were to implement AVX2/AVX512 optimizations from scratch, we could achieve a 36x performance boost; however, I haven't found any similar reference code in the ONNX Runtime repository. This PR also significantly improves parallelism:

<img width="540" height="332" alt="image" src="https://github.com/user-attachments/assets/d4f670bd-dde3-43f1-b597-4471bfde005b" />

_Both ort and tv are configured with 16 threads._

### Open Question for Reviewers

**Regarding CUDA Temporary Memory Allocation:** Currently, the effective maximum temporary memory for CUDA is calculated using a heuristic (`total_global_mem * 0.1` or similar logic in `GetDeformConvEffectiveMaxTempBytes`). While the removal of `gemm_output_buffer` has reduced the memory footprint per image, I am not entirely certain this 10% threshold is still the most appropriate value for balancing parallel image processing (`n_parallel_imgs`) against overall VRAM consumption in large models.
I would appreciate any feedback or suggestions on whether we should tune this threshold, or if there's a more robust way to dynamically determine the optimal temporary workspace size for `DeformConv` in ORT.
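To make the "Fast Math & Branchless Logic" row concrete: the idea behind `DeformConvFastFloor` and the inverted bounds check can be sketched as below. This is an illustrative sketch, not the actual ORT code; `CornerValidMask` and its bit layout are hypothetical names. The inverted check relies on the fact that casting a negative index to unsigned makes `x >= 0 && x < n` collapse into a single unsigned comparison, and combining the per-corner results with bitwise `&`/`|` keeps the hot path branch-free.

```cpp
#include <cassert>
#include <cstdint>

// Truncation rounds toward zero, so subtract 1 when the value is negative
// and truncation rounded it up. Avoids a std::floor library call.
inline int64_t DeformConvFastFloor(float v) {
  int64_t i = static_cast<int64_t>(v);
  return (static_cast<float>(i) > v) ? i - 1 : i;
}

// Hypothetical helper: branchless validity mask for the four bilinear
// corners around the sampling point (h, w) in an H x W image.
// Bit i is set iff corner (h0 + (i >> 1), w0 + (i & 1)) is addressable.
inline unsigned CornerValidMask(float h, float w, int64_t H, int64_t W) {
  const int64_t h0 = DeformConvFastFloor(h);
  const int64_t w0 = DeformConvFastFloor(w);
  // Unsigned compare folds (x >= 0 && x < n) into one comparison.
  const auto in = [](int64_t x, int64_t n) -> unsigned {
    return static_cast<unsigned>(static_cast<uint64_t>(x) <
                                 static_cast<uint64_t>(n));
  };
  const unsigned h0_ok = in(h0, H), h1_ok = in(h0 + 1, H);
  const unsigned w0_ok = in(w0, W), w1_ok = in(w0 + 1, W);
  return (h0_ok & w0_ok) | ((h0_ok & w1_ok) << 1) |
         ((h1_ok & w0_ok) << 2) | ((h1_ok & w1_ok) << 3);
}
```

With all four corner flags available as one mask, a sampler can zero out out-of-bounds contributions with multiplies instead of branches.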
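The batch-chunking rule under "Changed" (`k` need not divide `N`; the host loop takes `cur_parallel = min(k, N - b)`) amounts to a loop like the following sketch. Only `cur_parallel`, `k`, `N`, and `b` come from the PR text; `ChunkSizes` is an illustrative stand-in for the host loop, not an ORT function.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative sketch of the host-side batch loop: N images are processed
// in rounds of at most k images each. When k does not divide N, the last
// round handles the remainder.
std::vector<int64_t> ChunkSizes(int64_t N, int64_t k) {
  std::vector<int64_t> rounds;
  for (int64_t b = 0; b < N; b += k) {
    const int64_t cur_parallel = std::min(k, N - b);  // last chunk may be smaller
    rounds.push_back(cur_parallel);
    // ... run im2col + batched GEMM over images [b, b + cur_parallel) ...
  }
  return rounds;
}
```

For example, `N = 10, k = 4` yields rounds of 4, 4, and 2 images, minimizing the number of outer rounds under the temp-memory cap without requiring divisibility.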
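The zero-copy GEMM output works because of an index-arithmetic identity, sketched below under the standard im2col-GEMM assumptions (these assumptions, and both helper functions, are illustrative, not ORT code): if each per-image GEMM produces an `M x output_image_size` row-major block where `M` is the number of output channels and `output_image_size = H_out * W_out`, then a C-stride of `M * output_image_size` places batch `i`'s block exactly at image `i`'s offset in the NCHW `Y` tensor, so no scatter/copy kernel is needed.

```cpp
#include <cassert>
#include <cstdint>

// Offset at which batch i's GEMM output lands when the strided-batched
// GEMM is given stride_c_y = M * output_image_size (per the PR description).
int64_t GemmOutputOffset(int64_t i, int64_t M, int64_t output_image_size) {
  const int64_t stride_c_y = M * output_image_size;
  return i * stride_c_y;
}

// Offset of image i in an NCHW tensor with C_out channels and an
// H_out x W_out spatial grid.
int64_t NchwImageOffset(int64_t i, int64_t C_out, int64_t H_out, int64_t W_out) {
  return i * C_out * H_out * W_out;
}
```

Since `M == C_out` and `output_image_size == H_out * W_out`, the two offsets coincide for every `i`, which is what lets the GEMM write directly into `Y`.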