onnxruntime
6c63c1c9 - Multiple Gather to Split Fusion (#13095)

Commit
3 years ago
Multiple Gather to Split Fusion (#13095) For below code in some transformers models: ``` fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] ``` The exported graph will contains 3 Gather nodes, currently ORT's GatherGrad CUDA implementation is slow. This pattern can be fused to use one Split, so that we can launch less kernels for the compute, the perf of Split/Concat (for grad) is also better than Gather/GatherGrad. In a real example, one GatherGrad will take 15ms and there are 3 for each layer in the graph, after the fusion, one Concat takes only 35us. The total time of a step is improved from 1.5s to 0.4s.
Author
Parents
Loading