onnxruntime
2bda3fd3 - Gather to Slice Fusion (#13599)

Commit
3 years ago
Gather to Slice Fusion (#13599) This PR is to optimize the running for below code from Huggingface's XLNet model. ``` x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long)) ``` The code will be exported to Range->Gather, which can be fused to a Slice Op. Slice kernel is much faster than Gather, especially for backward run. The main reason is for Gather, the data in indices can be duplicated so that it needs sum during backward, but Slice node cannot have such case. Use Huggingface's XLNet model for profiling. - Before the fuse forward, ~753us ![image](https://user-images.githubusercontent.com/11661208/200758439-63f2f9b5-9610-4df8-98c8-a1ad4dc62f4e.png) backward, ~46101us ![image](https://user-images.githubusercontent.com/11661208/200758530-fe16a8ec-ea8f-4b79-b3ac-386b72ba1670.png) - After the fuse forward, ~627us ![image](https://user-images.githubusercontent.com/11661208/200758654-ab9a6068-c45d-40f4-9c71-3862a56732f8.png) backward, ~677us ![image](https://user-images.githubusercontent.com/11661208/200758833-aab1b8e1-1b5d-4e55-88cf-03c2a1d9d42b.png)
Author
Parents
Loading