pytorch
f2f8eeea - Inductor: fix Conv output stride for dynamic shapes (#121400)

Commit
309 days ago
Inductor: fix Conv output stride for dynamic shapes (#121400) Fixes https://github.com/pytorch/pytorch/issues/120873. Fixes the output stride of Conv in the case of dynamic shapes. The previous logic in inductor assumed that the output stride of Conv is always channels last while it is actually contiguous if `dynamic_shapes and is_contiguous_storage_and_layout(x)`. ### Static shape In static shape cases, since weight is prepacked (`weight_t.is_mkldnn()` will be `true`), we'll always force output to be channels last in the Conv kernel, thus it's fine to have the assumption in Inductor that the output stride of Conv is always channels last. https://github.com/pytorch/pytorch/blob/96ed37ac13366cc9a7e6645b8955061d0a14f80b/aten/src/ATen/native/mkldnn/Conv.cpp#L357-L358 ### Dynamic shape In dynamic shape cases, we won't do weight prepack for Conv, in this case, the Conv kernel decides the output layout based on the input and weight layout. https://github.com/pytorch/pytorch/blob/96ed37ac13366cc9a7e6645b8955061d0a14f80b/torch/_inductor/fx_passes/mkldnn_fusion.py#L1024-L1025 For input with `channels = 1`, like tensor of size `(s0, 1, 28, 28)` and stride `(784, 784, 28, 1)`, in Inductor, with `req_stride_order` in channels last order, the `require_stride_order` on `x` of such size and stride won't change the stride of the tensor since stride for dimensions of size 1 is ignored https://github.com/pytorch/pytorch/blob/96ed37ac13366cc9a7e6645b8955061d0a14f80b/torch/_inductor/ir.py#L5451 While in Conv kernel, such tensor is consider it as **contiguous** tensor instead of channels last tensor thus the output of the Conv kernel will be in contiguous format. https://github.com/pytorch/pytorch/blob/96ed37ac13366cc9a7e6645b8955061d0a14f80b/aten/src/ATen/native/ConvUtils.h#L396-L404 To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121400 Approved by: https://github.com/jgong5, https://github.com/jansel
Author
Committer
Parents
Loading