@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg
Author: Hsiangkai Wang (Hsiangkai)
Define high-level Winograd operators and convert conv_2d_nhwc_fhwc into them. According to the Winograd Conv2D algorithm, we need three transform operators, for input, filter, and output transformation.
The formula of the Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

- filter transform: G x g x G^T
- input transform: B^T x d x B
- output transform: A^T x y x A
The implementation is based on the paper Fast Algorithms for Convolutional Neural Networks (https://arxiv.org/abs/1509.09308).
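For example, with F(4, 3), a static conv_2d_nhwc_fhwc with a 1x10x10x5 input, a 2x3x3x5 filter, and a 1x8x8x2 output would be rewritten roughly as follows (a sketch; the value names and the elided collapse/matmul/expand steps are illustrative):

```mlir
%f = linalg.winograd_filter_transform m(4) r(3)
       ins(%filter : tensor<2x3x3x5xf32>)
       outs(%f_init : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
%d = linalg.winograd_input_transform m(4) r(3)
       ins(%input : tensor<1x10x10x5xf32>)
       outs(%d_init : tensor<2x2x6x6x1x5xf32>) -> tensor<2x2x6x6x1x5xf32>
// ... collapse %d and %f to 3-D, linalg.batch_matmul, expand back to 6-D ...
%y = linalg.winograd_output_transform m(4) r(3)
       ins(%m : tensor<2x2x6x6x1x2xf32>)
       outs(%out : tensor<1x8x8x2xf32>) -> tensor<1x8x8x2xf32>
```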
Patch is 45.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96181.diff
7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..de1097b6ac27b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,118 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
let hasVerifier = 1;
}
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
+ let summary = "Winograd filter transform operator";
+ let description = [{
+ The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+ batched matrix multiply. Before the matrix multiply, it transforms the
+ filter and the input into a format suitable for batched matrix
+ multiplication. After the matrix multiply, it transforms the result into
+ the final output tensor.
+
+ The algorithm F(m x m, r x r) is
+
+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+ The size of output Y is m x m. The size of filter g is r x r. The size of
+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+ transformation matrices.
+
+ This operator is defined to represent the high-level concept of filter
+ transformation (G x g x G^T) in the Winograd Conv2D algorithm.
+ }];
+
+ let arguments = (ins AnyRankedTensor:$filter,
+ AnyRankedTensor:$output,
+ I64Attr:$m,
+ I64Attr:$r
+ );
+
+ let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ attr-dict
+ `m` `(` $m `)`
+ `r` `(` $r `)`
+ `ins` `(` $filter `:` type($filter) `)`
+ `outs` `(` $output `:` type($output) `)`
+ `->` type($result)
+ }];
+ let hasVerifier = 1;
+}
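For example, the textual form implied by the assembly format above should look roughly like this (a sketch with illustrative shapes, assuming F(4, 3), i.e. alpha = 6, with 2x2 output tiles, C = 5, and F = 2):

```mlir
%1 = linalg.winograd_filter_transform m(4) r(3)
       ins(%filter : tensor<2x3x3x5xf32>)
       outs(%0 : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
```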
+
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
+ let summary = "Winograd input transform operator";
+ let description = [{
+ The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+ batched matrix multiply. Before the matrix multiply, it transforms the
+ filter and the input into a format suitable for batched matrix
+ multiplication. After the matrix multiply, it transforms the result into
+ the final output tensor.
+
+ The algorithm F(m x m, r x r) is
+
+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+ The size of output Y is m x m. The size of filter g is r x r. The size of
+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+ transformation matrices.
+
+ This operator is defined to represent the high-level concept of input
+ transformation (B^T x d x B) in the Winograd Conv2D algorithm.
+ }];
+
+ let arguments = (ins AnyRankedTensor:$input,
+ AnyRankedTensor:$output,
+ I64Attr:$m,
+ I64Attr:$r
+ );
+
+ let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ attr-dict
+ `m` `(` $m `)`
+ `r` `(` $r `)`
+ `ins` `(` $input `:` type($input) `)`
+ `outs` `(` $output `:` type($output) `)`
+ `->` type($result)
+ }];
+ let hasVerifier = 1;
+}
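For example, for the same F(4, 3) configuration, where the input is already aligned to full tiles (10 = 2 x 4 + 3 - 1), a sketch of the textual form:

```mlir
%1 = linalg.winograd_input_transform m(4) r(3)
       ins(%input : tensor<1x10x10x5xf32>)
       outs(%0 : tensor<2x2x6x6x1x5xf32>) -> tensor<2x2x6x6x1x5xf32>
```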
+
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
+ let summary = "Winograd output transform operator";
+ let description = [{
+ The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+ batched matrix multiply. Before the matrix multiply, it transforms the
+ filter and the input into a format suitable for batched matrix
+ multiplication. After the matrix multiply, it transforms the result into
+ the final output tensor.
+
+ The algorithm F(m x m, r x r) is
+
+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+ The size of output Y is m x m. The size of filter g is r x r. The size of
+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+ transformation matrices.
+
+ This operator is defined to represent the high-level concept of output
+ transformation (A^T x y x A) in the Winograd Conv2D algorithm.
+ }];
+
+ let arguments = (ins AnyRankedTensor:$value,
+ AnyRankedTensor:$output,
+ I64Attr:$m,
+ I64Attr:$r
+ );
+
+ let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ attr-dict
+ `m` `(` $m `)`
+ `r` `(` $r `)`
+ `ins` `(` $value `:` type($value) `)`
+ `outs` `(` $output `:` type($output) `)`
+ `->` type($result)
+ }];
+ let hasVerifier = 1;
+}
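For example, a sketch of the textual form for the same F(4, 3) configuration, going from the 6-D batched-matmul result back to the 4-D output:

```mlir
%1 = linalg.winograd_output_transform m(4) r(3)
       ins(%value : tensor<2x2x6x6x1x2xf32>)
       outs(%0 : tensor<1x8x8x2xf32>) -> tensor<1x8x8x2xf32>
```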
+
#endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..835aeaf2ffed3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+ int64_t r);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..7bf2a5bca037f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,6 +2734,84 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
return SmallVector<Value>{result};
}
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradFilterTransformOp::verify() {
+ auto filterType = cast<ShapedType>(getFilter().getType());
+ auto outputType = cast<ShapedType>(getOutput().getType());
+ auto filterElemType = filterType.getElementType();
+ auto outputElemType = outputType.getElementType();
+ if (filterElemType != outputElemType) {
+ return emitOpError() << "expected element type of input " << filterElemType
+ << " to match element type of output "
+ << outputElemType;
+ }
+
+ unsigned filterRank = filterType.getRank();
+ if (filterRank != 4)
+ return emitOpError() << "expected rank of input is 4";
+
+ unsigned outputRank = outputType.getRank();
+ if (outputRank != 6)
+ return emitOpError() << "expected rank of output is 6";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradInputTransformOp::verify() {
+ auto inputType = cast<ShapedType>(getInput().getType());
+ auto outputType = cast<ShapedType>(getOutput().getType());
+ auto inputElemType = inputType.getElementType();
+ auto outputElemType = outputType.getElementType();
+ if (inputElemType != outputElemType) {
+ return emitOpError() << "expected element type of input " << inputElemType
+ << " to match element type of output "
+ << outputElemType;
+ }
+
+ unsigned inputRank = inputType.getRank();
+ if (inputRank != 4)
+ return emitOpError() << "expected rank of input is 4";
+
+ unsigned outputRank = outputType.getRank();
+ if (outputRank != 6)
+ return emitOpError() << "expected rank of output is 6";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradOutputTransformOp::verify() {
+ auto valueType = cast<ShapedType>(getValue().getType());
+ auto outputType = cast<ShapedType>(getOutput().getType());
+ auto valueElemType = valueType.getElementType();
+ auto outputElemType = outputType.getElementType();
+ if (valueElemType != outputElemType) {
+ return emitOpError() << "expected element type of value " << valueElemType
+ << " to match element type of output "
+ << outputElemType;
+ }
+
+ unsigned valueRank = valueType.getRank();
+ if (valueRank != 6)
+ return emitOpError() << "expected rank of input is 6";
+
+ unsigned outputRank = outputType.getRank();
+ if (outputRank != 4)
+ return emitOpError() << "expected rank of output is 4";
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Transforms.cpp
TransposeConv2D.cpp
Vectorization.cpp
+ WinogradConv2D.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..86e834d51f2fc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,321 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+using TransformMapKeyTy = std::pair<int, int>;
+
+// We use F(m, r) to define the size of minimal filtering algorithms.
+// m is the output dimension and r is the filter dimension. We can get
+// the input dimension, alpha, from the formula, alpha = m + r - 1.
+//
+// For example, when m = 2 and r = 3, we know its input size is 4.
+// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
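+// Collapse the first four dimensions (tileH, tileW, H, W) of a 6-D tensor
+// into a single batch dimension, producing the 3-D shape expected by
+// linalg.batch_matmul.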
+Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
+ auto type = cast<ShapedType>(data.getType());
+ auto elementType = type.getElementType();
+ auto shape = type.getShape();
+ auto collapseType = RankedTensorType::get(
+ {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]},
+ elementType);
+ SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
+ return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data,
+ reassociation);
+}
+
+// This function generates linalg.batch_matmul to multiply the input with the
+// filter. linalg.batch_matmul only supports 3-dimensional operands, so we
+// treat the tileH x tileW x H x W dimensions as a single batch dimension.
+// That is, we convert [tileH, tileW, H, W, N, C] to
+// [tileH x tileW x H x W, N, C]. In this way, the 6-dimensional inputs get a
+// 3-dimensional representation suitable for linalg.batch_matmul.
+//
+// The batched matmul performs the matrix multiply with the reduction on the
+// channel dimension.
+//
+// We get
+//
+// %collapsed_input = tensor.collapse_shape %input
+// %collapsed_filter = tensor.collapse_shape %filter
+// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+// %expanded_ret = tensor.expand_shape %ret
+//
+// After this function, we get return value with data layout
+// (tileH, tileW, H, W, N, F).
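+//
+// For example (shapes illustrative): a transformed input of shape
+// [2, 2, 6, 6, 1, 5] and a transformed filter of shape [2, 2, 6, 6, 5, 2]
+// are collapsed to [144, 1, 5] and [144, 5, 2], the batched matmul produces
+// [144, 1, 2], and the result is expanded back to [2, 2, 6, 6, 1, 2].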
+Value matrixMultiply(RewriterBase &rewriter, Location loc,
+ Value transformedFilter, Value transformedInput) {
+ auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter);
+ auto collapseInput = collapse2DData(rewriter, loc, transformedInput);
+
+ // Batched matrix multiply
+ auto filterType = cast<ShapedType>(transformedFilter.getType());
+ auto filterShape = filterType.getShape();
+ auto inputType = cast<ShapedType>(transformedInput.getType());
+ auto inputElemType = inputType.getElementType();
+ auto inputShape = inputType.getShape();
+
+ auto matmulType = RankedTensorType::get(
+ {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
+ inputShape[4], filterShape[5]},
+ inputElemType);
+ Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ inputElemType);
+
+ auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
+ loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+ ValueRange{init});
+
+ // Expand matmul result
+ SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
+ auto expandType =
+ RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
+ inputShape[3], inputShape[4], filterShape[5]},
+ inputElemType);
+ auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
+ loc, expandType, matmulOp.getResult(0), reassociation);
+ return expandOutput;
+}
+
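+// Embed `value` at offset zero into a larger tensor of type `alignedType`
+// created with tensor.empty, via tensor.insert_slice, so that the input
+// shape is aligned to a whole number of tiles.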
+Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value,
+ RankedTensorType alignedType) {
+ Value alignedInput = rewriter.create<tensor::EmptyOp>(
+ loc, alignedType.getShape(), alignedType.getElementType());
+
+ auto zeroIndex = rewriter.getIndexAttr(0);
+ auto oneIndex = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
+ SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+ auto valueType = cast<ShapedType>(value.getType());
+ auto valueShape = valueType.getShape();
+ SmallVector<OpFoldResult, 4> sizes;
+ sizes.emplace_back(rewriter.getIndexAttr(valueShape[0]));
+ sizes.emplace_back(rewriter.getIndexAttr(valueShape[1]));
+ sizes.emplace_back(rewriter.getIndexAttr(valueShape[2]));
+ sizes.emplace_back(rewriter.getIndexAttr(valueShape[3]));
+
+ return rewriter.create<tensor::InsertSliceOp>(loc, value, alignedInput,
+ offsets, sizes, strides);
+}
+
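+// Extract a tensor of type `extractedType` from the top-left corner of
+// `value` via tensor.extract_slice. This is the inverse of
+// insertToAlignedTensor and drops the alignment padding.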
+Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
+ Value value, RankedTensorType extractedType) {
+ auto zeroIndex = rewriter.getIndexAttr(0);
+ auto oneIndex = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
+ SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+ auto extractedShape = extractedType.getShape();
+ SmallVector<OpFoldResult, 4> sizes;
+ sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0]));
+ sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1]));
+ sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2]));
+ sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3]));
+
+ return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
+ offsets, sizes, strides);
+}
+
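+// Return true if every element of the attribute equals 1. Used to check that
+// the convolution has unit strides and unit dilations.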
+bool hasAllOneValues(DenseIntElementsAttr attr) {
+ return llvm::all_of(
+ attr, [](const APInt &element) { return element.getSExtValue() == 1; });
+}
+
+FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
+ linalg::Conv2DNhwcFhwcOp convOp,
+ int64_t m, int64_t r) {
+ Value input = convOp.getInputs()[0];
+ Value filter = convOp.getInputs()[1];
+ Value output = convOp.getOutputs()[0];
+ auto inputType = cast<ShapedType>(input.getType());
+ auto filterType = cast<ShapedType>(filter.getType());
+ auto outputType = cast<ShapedType>(output.getType());
+
+ if (!inputType.hasStaticShape())
+ return rewriter.notifyMatchFailure(convOp,
+ "expected a static shape for the input");
+
+ if (!filterType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected a static shape for the filter");
+
+ if (!hasAllOneValues(convOp.getDilations()))
+ return rewriter.notifyMatchFailure(convOp,
+ "expected all ones for dilations");
+
+ if (!hasAllOneValues(convOp.getStrides()))
+ return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
+
+ auto filterShape = filterType.getShape();
+ int64_t filterF = filterShape[0];
+ int64_t filterH = filterShape[1];
+ int64_t filterW = filterShape[2];
+ int64_t filterC = filterShape[3];
+ auto inputShape = inputType.getShape();
+ int64_t inputN = inputShape[0];
+ int64_t inputH = inputShape[1];
+ int64_t inputW = inputShape[2];
+ int64_t inputC = inputShape[3];
+ auto outputShape = outputType.getShape();
+ int64_t outputN = outputShape[0];
+ int64_t outputH = outputShape[1];
+ int64_t outputW = outputShape[2];
+ int64_t outputF = outputShape[3];
+
+ // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r)
+ bool isSupportedFilter = false;
+ if (filterH == filterW && filterH == r)
+ isSupportedFilter = true;
+ if (filterH == r && filterW == 1)
+ isSupportedFilter = true;
+ if (filterH == 1 && filterW == r)
+ isSupportedFilter = true;
+
+ if (!isSupportedFilter)
+ return rewriter.notifyMatchFailure(
+ convOp, "only support filter (r x r), (r x 1) or (1 x r)");
+
+ // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5)
+ static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
+ F_2_3, F_4_3, F_2_5};
+
+ TransformMapKeyTy key = {m, r};
+ auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
+ // If we cannot find the constant transformation matrix, it means we do
+ // not support this configuration yet.
+ if (it == validConfigs.end())
+ return failure();
+
+ // All the criteria are satisfied. We can do Winograd Conv2D.
+ Location loc = convOp.getLoc();
+
+ // For F(m x 1, r x 1), we only need to do left side transform.
+ bool leftTransform = filterH != 1;
+ // For F(1 x m, 1 x r), we only need to do right side transform.
+ bool rightTransform = filterW != 1;
+ int64_t heightM = leftTransform ? m : 1;
+ int64_t widthM = rightTransform ? m : 1;
+ int64_t heightR = leftTransform ? r : 1;
+ int64_t widthR = rightTransform ? r : 1;
+
+ // --- Create operator for filter transform ---
+ Type elementType = filterType.getElementType();
+ int64_t alphaH = heightM + heightR - 1;
+ int64_t alphaW = widthM + widthR - 1;
+ int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
+ int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
+ auto retType = RankedTensorType::get(
+ {tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType);
+ Value retValue =
+ rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+ auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
+ loc, retType, filter, retValue, m, r);
+
+ // --- Create operator for input transform ---
+
+ // When the input size minus (r - 1) is not aligned with the output tile
+ // size, we need to pad the input data to create full tiles.
+ int64_t alignedInputH = tileH * heightM + (heightR - 1);
+ int64_t alignedInputW = tileW * widthM + (widthR - 1);
+ if (alignedInputH != inputH || alignedInputW != inputW) {
+ auto alignedInputType = RankedTensorType::get(
+ {inputN, alignedInputH, alignedInputW, inputC}, elementType);
+ input = insertToAlignedTensor(rewriter, loc, input, alignedInputType);
+ }
+
+ retType = RankedTensorType::get(
+ {tileH, tileW, alphaH, alphaW, inputN, inputC}, elementType);
+ retValue =
+ ...
[truncated]
Looks good in general. Please document functions and address nitpicks, and this will be ready to go!
Note that you can use linalg.generic with explicit indexing maps and multiple batch dimensions; that will remove the need to manipulate shapes. Not insisting on this change, but something to consider, and maybe comment as to why the choice was made one way or another.
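For illustration, such a linalg.generic with multiple batch dimensions might look roughly like this (a hypothetical sketch with illustrative shapes, not part of this patch):

```mlir
#map_d = affine_map<(t0, t1, h, w, n, f, c) -> (t0, t1, h, w, n, c)>
#map_f = affine_map<(t0, t1, h, w, n, f, c) -> (t0, t1, h, w, c, f)>
#map_y = affine_map<(t0, t1, h, w, n, f, c) -> (t0, t1, h, w, n, f)>
%y = linalg.generic
       {indexing_maps = [#map_d, #map_f, #map_y],
        iterator_types = ["parallel", "parallel", "parallel", "parallel",
                          "parallel", "parallel", "reduction"]}
       ins(%d, %f : tensor<2x2x6x6x1x5xf32>, tensor<2x2x6x6x5x2xf32>)
       outs(%init : tensor<2x2x6x6x1x2xf32>) {
^bb0(%in: f32, %flt: f32, %acc: f32):
  %p = arith.mulf %in, %flt : f32
  %s = arith.addf %acc, %p : f32
  linalg.yield %s : f32
} -> tensor<2x2x6x6x1x2xf32>
```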
To be honest, it is difficult for me to reason about using linalg.generic with indexing maps and multiple batch dimensions. In addition, we have optimizations for batched matmul. It is easier to match the operator directly than to find the matmul pattern inside a linalg.generic.
High-level design question: what are we getting from introducing these ops instead of just directly implementing the transformation and emitting more Linalg? Are the ops getting transformed?
Nice improvements so far! A few comments:
> High-level design question: what are we getting from introducing these ops instead of just directly implementing the transformation and emitting more Linalg? Are the ops getting transformed?
These ops carry an implicit level of tiling, where the non-input tile dimensions (N, C, and the expanded H and W for the input transform) get tiled to 1. Then, the ops can be decomposed into a series of linalg.matmul ops. The main benefit of having these ops is that they are a canonical form of the operation that doesn't have to be split into multiple ops within a loop nest. This makes tiling neater, especially when trying to tile and fuse the Winograd ops.
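For example, one tile of the filter transform for F(2, 3) could decompose into a pair of linalg.matmul ops like this (a sketch; %G and %GT are the 4x3 and 3x4 constant transformation matrices, %g is a single 3x3 filter slice, and the names are illustrative):

```mlir
// G x g : (4x3) x (3x3) -> 4x3
%Gg = linalg.matmul ins(%G, %g : tensor<4x3xf32>, tensor<3x3xf32>)
        outs(%init0 : tensor<4x3xf32>) -> tensor<4x3xf32>
// (G x g) x G^T : (4x3) x (3x4) -> 4x4
%GgGT = linalg.matmul ins(%Gg, %GT : tensor<4x3xf32>, tensor<3x4xf32>)
          outs(%init1 : tensor<4x4xf32>) -> tensor<4x4xf32>
```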
Nice work! Just a couple comments and some nits.
Can you add some tests to https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Linalg/roundtrip.mlir and https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Linalg/invalid.mlir?
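A roundtrip test might look something like this (a sketch; shapes illustrative):

```mlir
// CHECK-LABEL: func @winograd_filter_transform
func.func @winograd_filter_transform(%filter: tensor<2x3x3x5xf32>,
    %init: tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32> {
  // CHECK: linalg.winograd_filter_transform m(4) r(3)
  %0 = linalg.winograd_filter_transform m(4) r(3)
         ins(%filter : tensor<2x3x3x5xf32>)
         outs(%init : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
  return %0 : tensor<2x2x6x6x5x2xf32>
}
```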
These verifiers will not work for dynamic shapes. Can you support dynamic cases? The transform is only supported for static shapes right now, but shapes can become dynamic when tiling.
You can create an expected output shape from the input, allowing dynamic dims, and compare with the actual output shape. This helper may be useful:
This way will also make it easy to check that the batch/channel dimensions match for the input and output.
Can we support static shapes in the upstream first? I added a TODO for dynamic cases.
It has a tendency to not get done, and not having the shape handling can mask other issues. Usually it's better to do it right and do it once.
Thanks for your review. I updated the verify() functions to consider dynamic shapes. I also added test cases for dynamic shapes.
Looks good, thanks! Just make sure to switch to ShapedType::kDynamic before landing.