pytorch
40df6e16 - [ONNX] Simplify repeat_intereleave export for scalar-valued 'repeat' (#100575)

Commit
2 years ago
[ONNX] Simplify repeat_intereleave export for scalar-valued 'repeat' (#100575) This PR simplifies the ONNX export of torch.repeat_interleave when 'repeat' is a scalar value (so each index in the input is repeated the same number of times). (Issue #100438) Here is a before/after of a simple model export: ```python # Model + export code import torch class RepeatInterleaveModel(torch.nn.Module): def forward(self, x): return x.repeat_interleave(2, dim=-1) args = (torch.rand((2, 2, 16)),) model = RepeatInterleaveModel() torch.onnx.export(model, args, "repeat_interleave.onnx", opset_version=17) ``` **Before (static shapes)** ![repeat_interleave onnx(1)](https://user-images.githubusercontent.com/46343317/236014996-00726832-1e76-4fb4-950d-4b54cc5cc20c.png) ----- **Before (dynamic shapes, second graph is Loop body)** <p float="left"> <img src="https://user-images.githubusercontent.com/46343317/236029895-20b0ae0a-240f-466d-bb01-e619ec5967ad.png" width="45%" /> <img src="https://user-images.githubusercontent.com/46343317/236029915-e67b808a-029b-4997-bc05-1ce59eec409a.png" width="47%" /> </p> ----- **After (for both static and dynamic shapes)** <img src="https://user-images.githubusercontent.com/46343317/236015235-633811cb-09a2-435d-a293-1b2bcb7dea50.png" width="66%" /> ----- This PR also fixes a bug where the exporter throws an expection when the input has dynamic shapes and the 'dim' parameter is not specified to torch.repeat_interleave. Also adds a new testcase to cover this. (Issue #100429) Fixes #100438 and #100429 Pull Request resolved: https://github.com/pytorch/pytorch/pull/100575 Approved by: https://github.com/BowenBao
Author
Committer
Parents
Loading