llvm-project
[mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm
#96181
Merged

[mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm #96181

Hsiangkai merged 8 commits into main from users/hsiangkai/winograd-ops
Hsiangkai
Hsiangkai320 days ago

Define high level winograd operators and convert conv_2d_nhwc_fhwc into winograd operators. According to Winograd Conv2D algorithm, we need three transform operators for input, filter, and output transformation.

The formula of Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A

The implementation is based on the paper, Fast Algorithm for Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)

Hsiangkai [mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm
4240341b
llvmbot llvmbot added mlir:linalg
llvmbot llvmbot added mlir
llvmbot
llvmbot320 days ago (edited 320 days ago)

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Hsiangkai Wang (Hsiangkai)

Changes

Define high level winograd operators and convert conv_2d_nhwc_fhwc into winograd operators. According to Winograd Conv2D algorithm, we need three transform operators for input, filter, and output transformation.

The formula of Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A

The implementation is based on the paper, Fast Algorithm for Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)


Patch is 45.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96181.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+114)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+4)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+78)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+321)
  • (added) mlir/test/Dialect/Linalg/winograd-conv2d.mlir (+248)
  • (modified) mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (+13)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..de1097b6ac27b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,118 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
+  let summary = "Winograd filter transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of filter
+    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$filter,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
+  let summary = "Winograd input transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of input
+    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
+  let summary = "Winograd output transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of output
+    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$value,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..835aeaf2ffed3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..7bf2a5bca037f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,6 +2734,84 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto filterElemType = filterType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (filterElemType != outputElemType) {
+    return emitOpError() << "expected element type of input " << filterElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned filterRank = filterType.getRank();
+  if (filterRank != 4)
+    return emitOpError() << "expected rank of input is 4";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradInputTransformOp::verify() {
+  auto inputType = cast<ShapedType>(getInput().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto inputElemType = inputType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (inputElemType != outputElemType) {
+    return emitOpError() << "expected element type of input " << inputElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned inputRank = inputType.getRank();
+  if (inputRank != 4)
+    return emitOpError() << "expected rank of input is 4";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradOutputTransformOp::verify() {
+  auto valueType = cast<ShapedType>(getValue().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto valueElemType = valueType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (valueElemType != outputElemType) {
+    return emitOpError() << "expected element type of value " << valueElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned valueRank = valueType.getRank();
+  if (valueRank != 6)
+    return emitOpError() << "expected rank of input is 6";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 4)
+    return emitOpError() << "expected rank of output is 4";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..86e834d51f2fc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,321 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+using TransformMapKeyTy = std::pair<int, int>;
+
+// We use F(m, r) to define the size of minimal filtering algorithms.
+// m is the output dimension and r is the filter dimension. We can get
+// the input dimension, alpha, from the formula, alpha = m + r - 1.
+//
+// For example, when m = 2 and r = 3, we know its input size is 4.
+// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
+Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
+  auto type = cast<ShapedType>(data.getType());
+  auto elementType = type.getElementType();
+  auto shape = type.getShape();
+  auto collapseType = RankedTensorType::get(
+      {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]},
+      elementType);
+  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data,
+                                                  reassociation);
+}
+
+// This function generates linalg.batch_matmul to multiply input with filter.
+// linalg.batch_matmul only supports 3-dimension data sets. We can treat
+// tileH x tileW x H x W data as the 1-dimension data array. That is to convert
+// [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this way, we
+// can convert 6-dimension input data to 3-dimension representation that is
+// suitable for linalg.batch_matmul.
+//
+// Batched matmul will do the matrix multiply with the reduction on channel.
+//
+// We get
+//
+// %collapsed_input = tensor.collapse_shape %input
+// %collapsed_filter = tensor.collapse_shape %filter
+// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+// %expanded_ret = tensor.expand_shape %ret
+//
+// After this function, we get return value with data layout
+// (tileH, tileW, H, W, N, F).
+Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                     Value transformedFilter, Value transformedInput) {
+  auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter);
+  auto collapseInput = collapse2DData(rewriter, loc, transformedInput);
+
+  // Batched matrix multiply
+  auto filterType = cast<ShapedType>(transformedFilter.getType());
+  auto filterShape = filterType.getShape();
+  auto inputType = cast<ShapedType>(transformedInput.getType());
+  auto inputElemType = inputType.getElementType();
+  auto inputShape = inputType.getShape();
+
+  auto matmulType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
+       inputShape[4], filterShape[5]},
+      inputElemType);
+  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                inputElemType);
+
+  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
+      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+      ValueRange{init});
+
+  // Expand matmul result
+  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
+  auto expandType =
+      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
+                             inputShape[3], inputShape[4], filterShape[5]},
+                            inputElemType);
+  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
+      loc, expandType, matmulOp.getResult(0), reassociation);
+  return expandOutput;
+}
+
+Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value,
+                            RankedTensorType alignedType) {
+  Value alignedInput = rewriter.create<tensor::EmptyOp>(
+      loc, alignedType.getShape(), alignedType.getElementType());
+
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
+  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+  auto valueType = cast<ShapedType>(value.getType());
+  auto valueShape = valueType.getShape();
+  SmallVector<OpFoldResult, 4> sizes;
+  sizes.emplace_back(rewriter.getIndexAttr(valueShape[0]));
+  sizes.emplace_back(rewriter.getIndexAttr(valueShape[1]));
+  sizes.emplace_back(rewriter.getIndexAttr(valueShape[2]));
+  sizes.emplace_back(rewriter.getIndexAttr(valueShape[3]));
+
+  return rewriter.create<tensor::InsertSliceOp>(loc, value, alignedInput,
+                                                offsets, sizes, strides);
+}
+
+Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
+                               Value value, RankedTensorType extractedType) {
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
+  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+  auto extractedShape = extractedType.getShape();
+  SmallVector<OpFoldResult, 4> sizes;
+  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0]));
+  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1]));
+  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2]));
+  sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3]));
+
+  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
+                                                 offsets, sizes, strides);
+}
+
+bool hasAllOneValues(DenseIntElementsAttr attr) {
+  return llvm::all_of(
+      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
+}
+
+FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
+                                            linalg::Conv2DNhwcFhwcOp convOp,
+                                            int64_t m, int64_t r) {
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+  auto inputType = cast<ShapedType>(input.getType());
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto outputType = cast<ShapedType>(output.getType());
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  if (!hasAllOneValues(convOp.getStrides()))
+    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
+
+  auto filterShape = filterType.getShape();
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+  auto inputShape = inputType.getShape();
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  auto outputShape = outputType.getShape();
+  int64_t outputN = outputShape[0];
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int64_t outputF = outputShape[3];
+
+  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r)
+  bool isSupportedFilter = false;
+  if (filterH == filterW && filterH == r)
+    isSupportedFilter = true;
+  if (filterH == r && filterW == 1)
+    isSupportedFilter = true;
+  if (filterH == 1 && filterW == r)
+    isSupportedFilter = true;
+
+  if (!isSupportedFilter)
+    return rewriter.notifyMatchFailure(
+        convOp, "only support filter (r x r), (r x 1) or (1 x r)");
+
+  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5)
+  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
+      F_2_3, F_4_3, F_2_5};
+
+  TransformMapKeyTy key = {m, r};
+  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
+  // If we cannot find the constant transformation matrix, it means we do
+  // not support this configuration yet.
+  if (it == validConfigs.end())
+    return failure();
+
+  // All the criterias are satisfied. We can do Winograd Conv2D.
+  Location loc = convOp.getLoc();
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = filterH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = filterW != 1;
+  int64_t heightM = leftTransform ? m : 1;
+  int64_t widthM = rightTransform ? m : 1;
+  int64_t heightR = leftTransform ? r : 1;
+  int64_t widthR = rightTransform ? r : 1;
+
+  // --- Create operator for filter transform ---
+  Type elementType = filterType.getElementType();
+  int64_t alphaH = heightM + heightR - 1;
+  int64_t alphaW = widthM + widthR - 1;
+  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
+  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
+  auto retType = RankedTensorType::get(
+      {tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType);
+  Value retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
+      loc, retType, filter, retValue, m, r);
+
+  // --- Create operator for input transform ---
+
+  // When input size - (r - 1) is not aligned with output tile size, we need to
+  // pad the input data to create the full tiles as tiling.
+  int64_t alignedInputH = tileH * heightM + (heightR - 1);
+  int64_t alignedInputW = tileW * widthM + (widthR - 1);
+  if (alignedInputH != inputH || alignedInputW != inputW) {
+    auto alignedInputType = RankedTensorType::get(
+        {inputN, alignedInputH, alignedInputW, inputC}, elementType);
+    input = insertToAlignedTensor(rewriter, loc, input, alignedInputType);
+  }
+
+  retType = RankedTensorType::get(
+      {tileH, tileW, alphaH, alphaW, inputN, inputC}, elementType);
+  retValue =
+ ...
[truncated]
Hsiangkai Hsiangkai requested a review from GeorgeARM GeorgeARM 320 days ago
Hsiangkai Hsiangkai requested a review from ftynse ftynse 320 days ago
Hsiangkai Hsiangkai requested a review from Max191 Max191 320 days ago
Hsiangkai Hsiangkai requested a review from cxy-1993 cxy-1993 320 days ago
Hsiangkai Hsiangkai requested a review from nicolasvasilache nicolasvasilache 320 days ago
Hsiangkai Hsiangkai requested a review from MaheshRavishankar MaheshRavishankar 320 days ago
ftynse
ftynse commented on 2024-06-21
ftynse319 days ago

Looks good in general. Please document functions and address nitpicks, and this will be ready to go!

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
2741LogicalResult WinogradFilterTransformOp::verify() {
2742 auto filterType = cast<ShapedType>(getFilter().getType());
2743 auto outputType = cast<ShapedType>(getOutput().getType());
2744
auto filterElemType = filterType.getElementType();
2745
auto outputElemType = outputType.getElementType();
ftynse319 days ago

Please expand auto unless the type is obvious from the line, e.g., RHS is a cast or is impossible to spell.

Hsiangkai316 days ago

Done.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
2749 << outputElemType;
2750 }
2751
2752
unsigned filterRank = filterType.getRank();
2753
if (filterRank != 4)
2754
return emitOpError() << "expected rank of input is 4";
ftynse319 days ago

This can be encoded in ODS by using TensorRankOf<[AnyType], [4]> instead of AnyRankedTensor as type constraint.

Hsiangkai316 days ago

Move the checking to ODS.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
2743 auto outputType = cast<ShapedType>(getOutput().getType());
2744 auto filterElemType = filterType.getElementType();
2745 auto outputElemType = outputType.getElementType();
2746
if (filterElemType != outputElemType) {
2747
return emitOpError() << "expected element type of input " << filterElemType
2748
<< " to match element type of output "
2749
<< outputElemType;
2750
}
ftynse319 days ago

This can be encoded in ODS by using the following trait: AllElementTypesMatch<["filter", "output"]>

Hsiangkai316 days ago

Move the checking to ODS.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
2757 if (outputRank != 6)
2758 return emitOpError() << "expected rank of output is 6";
2759
2760
return success();
ftynse319 days ago

Shouldn't this also check the values of m and r to be in the set of supported ones?

Hsiangkai316 days ago

Winograd operators should be able to support other values of m and r. I already check the valid values in the OpRewritePattern.
In the verify function, I added checking for the relationship between dimension values.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
36constexpr TransformMapKeyTy F_4_3{4, 3};
37constexpr TransformMapKeyTy F_2_5{2, 5};
38
39
Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
ftynse319 days ago

Please document top-level entities.

Also declare functions used within a single translation unit as static. https://llvm.org/docs/CodingStandards.html#anonymous-namespaces

Hsiangkai316 days ago

Done.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
25
26using TransformMapKeyTy = std::pair<int, int>;
27
28
// We use F(m, r) to define the size of minimal filtering algorithms.
ftynse319 days ago

MLIR uses /// for top-level documentation comments.

Hsiangkai316 days ago

Done.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
40 auto type = cast<ShapedType>(data.getType());
41 auto elementType = type.getElementType();
42 auto shape = type.getShape();
43
auto collapseType = RankedTensorType::get(
44
{shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]},
45
elementType);
ftynse319 days ago

Please add an assertion about the shape being static. I see this is only called in such a case, but we may eventually want to generalize this to support dynamic shapes.

Hsiangkai316 days ago

Done.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
49}
50
51// This function generates linalg.batch_matmul to multiply input with filter.
52
// linalg.batch_matmul only supports 3-dimension data sets. We can treat
ftynse319 days ago
Suggested change
// linalg.batch_matmul only supports 3-dimension data sets. We can treat
// linalg.batch_matmul only supports 3-dimensional inputs. We can treat
ftynse319 days ago

Same below.

Hsiangkai316 days ago

Done.

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
59//
60// We get
61//
62
// %collapsed_input = tensor.collapse_shape %input
63
// %collapsed_filter = tensor.collapse_shape %filter
64
// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
65
// %expanded_ret = tensor.expand_shape %ret
ftynse319 days ago

Note that you can use linalg.generic with explicit indexing maps and multiple batch dimensions that will remove the need to manipulate shapes. Not insisting on this change, but something to consider and maybe comment as to why the choice was made one way or another.

Hsiangkai316 days ago

Be honest, it is difficult for me to reason about using linalg-generic with indexing maps and multiple batch dimensions. In addition, we have optimisation for batched matmul. It is easier to match the operator directly, instead of finding the matmul pattern inside linalg.generic.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
71 auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter);
72 auto collapseInput = collapse2DData(rewriter, loc, transformedInput);
73
74
// Batched matrix multiply
ftynse319 days ago
Suggested change
// Batched matrix multiply
// Batched matrix multiply.

Nit: Please end sentences in comments with a full stop. Here and below.

Hsiangkai316 days ago

Done.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
112
113 auto valueType = cast<ShapedType>(value.getType());
114 auto valueShape = valueType.getShape();
115
SmallVector<OpFoldResult, 4> sizes;
116
sizes.emplace_back(rewriter.getIndexAttr(valueShape[0]));
117
sizes.emplace_back(rewriter.getIndexAttr(valueShape[1]));
118
sizes.emplace_back(rewriter.getIndexAttr(valueShape[2]));
119
sizes.emplace_back(rewriter.getIndexAttr(valueShape[3]));
ftynse319 days ago

Nit: I'm pretty sure we have a getAsOpFoldResult somewhere that does this.

Hsiangkai316 days ago

Thanks for your tips. It is more elegant to use getAsOpFoldResult.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
222 int64_t heightR = leftTransform ? r : 1;
223 int64_t widthR = rightTransform ? r : 1;
224
225
// --- Create operator for filter transform ---
ftynse319 days ago
Suggested change
// --- Create operator for filter transform ---
// --- Create operation for filter transform ---

Here and below.

Hsiangkai316 days ago

Done.

Conversation is marked as resolved
Show resolved
mlir/test/Dialect/Linalg/winograd-conv2d.mlir
2
3func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
4 %0 = tensor.empty() : tensor<2x4x4x2xf32>
5
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
6
^bb0(%in: f32, %out: f32):
7
linalg.yield %in : f32
8
} -> tensor<2x4x4x2xf32>
ftynse319 days ago

Why do we have this in the test?

Hsiangkai316 days ago

Remove the unrelated initialisation.

Conversation is marked as resolved
Show resolved
mlir/test/Dialect/Linalg/winograd-conv2d.mlir
15// CHECK-LABEL: func.func @conv2d_4x4_3x3
16// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
17// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
18
// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
ftynse319 days ago

We don't care about these operations being strictly on the next line, let's not overconstrain tests.

Hsiangkai316 days ago

Removed.

Hsiangkai Address ftynse's comments
bbb6542f
ftynse
ftynse316 days ago

High-level design question: what are we getting from introducing these ops instead of just directly implementing the transformation and emitting more Linalg? Are the ops getting transformed?

Max191
Max191 requested changes on 2024-06-24
Max191316 days ago

Nice improvements so far! A few comments:

  1. The shapes for the filter transform op and the batch_matmul don't quite make sense to me. See the comments about the filter transform.
  2. When the input image tensor is padded, the padded part of the tensor needs to be filled with zeros
  3. It would be great to support NCHW conv layouts. This can come as a later PR, but would be really useful to have.
Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
150 attr, [](const APInt &element) { return element.getSExtValue() == 1; });
151}
152
153
/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
154
/// linalg.winograd_*_transform ops.
155
static FailureOr<Operation *>
156
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
157
int64_t m, int64_t r) {
Max191316 days ago

Can we also support NCHW conv types? It is okay if it comes as a later PR, but NCHW convs are important for winograd performance. The N and C dimensions are tiled to 1 with winograd, and a slice is extracted along HW in the decomposition, so having W be innermost allows for vectorized loads.

Hsiangkai314 days ago👍 1

How about to have another PR later for different data layout.

Max191314 days ago

Can we add a TODO comment somewhere?

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
106 return expandOutput;
107}
108
109
/// Create an empty tensor with alignedType and insert the value into the
110
/// created empty tensor with aligned size.
111
static Value insertToAlignedTensor(RewriterBase &rewriter, Location loc,
112
Value value, RankedTensorType alignedType) {
Max191316 days ago

This needs to actually pad the value with zeros. Inserting directly into a tensor.empty means the padding is uninitialized memory. Instead, could you create a tensor.pad here? This will fix the issue of uninitialized memory, and the tensor.pad can also fold into the tensor.pad that usually comes as a producer to the conv input image.

Hsiangkai314 days ago

Modified to use tensor.pad.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
229 int64_t heightR = leftTransform ? r : 1;
230 int64_t widthR = rightTransform ? r : 1;
231
232
// --- Create operation for filter transform ---
233
Type elementType = filterType.getElementType();
234
int64_t alphaH = heightM + heightR - 1;
235
int64_t alphaW = widthM + widthR - 1;
236
int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
237
int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
238
auto retType = RankedTensorType::get(
239
{tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType);
Max191316 days ago

This looks wrong to me. I would expect the filter transform op return type to be a 4D tensor of shape (alphaH)x(alphaW)xCxF. This seems to be adding some extra tileH and tileW dimensions based on the output size of the convolution. Is there a reason for having these tileH and tileW dims?

Hsiangkai314 days ago

You are right. There is no need to have tileH and tileW for filter.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
88 {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
89 inputShape[4], filterShape[5]},
90 inputElemType);
91
Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
92
inputElemType);
Max191316 days ago

Some conv_2d ops can have implicit element type promotion, increasing the bitwidth of the element type. The matmul output element type should match the output element type of the convolution instead of the input image element type, since they may not be the same.

Hsiangkai314 days ago

Done.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
84 Type inputElemType = inputType.getElementType();
85 ArrayRef<int64_t> inputShape = inputType.getShape();
86
87
auto matmulType = RankedTensorType::get(
88
{inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
89
inputShape[4], filterShape[5]},
90
inputElemType);
Max191316 days ago

In the below comment, I believe the filter transform shape is not quite right, which may be why the matmul shape is batched in this way, but I would expect this to be:

Suggested change
auto matmulType = RankedTensorType::get(
{inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
inputShape[4], filterShape[5]},
inputElemType);
auto matmulType = RankedTensorType::get(
{inputShape[2] * inputShape[3],
inputShape[0] * inputShape[1] * inputShape[4],
filterShape[5]},
inputElemType);

The input transform result shape is (tileH, tileW, alphaH, alphaW, inputN, inputC), and filter transform shape that I would expect to see is (alphaH, alphaW, filterC, filterF). The shared dimensions in this case are alphaH and alphaW, so those should be batch. tileH and tileW would be M dimensions of the batch matmul.

An additional suggestion based on this is to change the layout of the input transform result from (tileH, tileW, alphaH, alphaW, inputN, inputC) to (alphaH, alphaW, tileH, tileW, inputN, inputC), in order to have batch dimensions be outermost on the batch_matmul. This can help with performance of the matmul in some cases, and it means you can use the linalg.batch_matmul named op instead of a linalg.generic op. With this different layout, the matmul shape would be:

Suggested change
auto matmulType = RankedTensorType::get(
{inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
inputShape[4], filterShape[5]},
inputElemType);
auto matmulType = RankedTensorType::get(
{inputShape[0] * inputShape[1],
inputShape[2] * inputShape[3] * inputShape[4],
filterShape[5]},
inputElemType);
Hsiangkai314 days ago

Updated.

Max191
Max191316 days ago

High-level design question: what are we getting from introducing these ops instead of just directly implementing the transformation and emitting more Linalg? Are the ops getting transformed?

These ops carry an implicit level of tiling, where the non-input tile dimensions (N, C, and expanded H and W for input transform) get tiled to 1. Then, the ops can be decomposed into a series of linalg.matmul ops. The main benefit of having these ops is that they are a canonical form of the operation that doesn't have to be split into multiple ops within a loop nest. This makes tiling neater, especially when trying to tile and fuse the winograd ops.

Hsiangkai Address Max191's comments
db8e7e7d
Hsiangkai Hsiangkai requested a review from Max191 Max191 314 days ago
Max191
Max191 requested changes on 2024-06-26
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
27342734 return SmallVector<Value>{result};
27352735}
27362736
2737
//===----------------------------------------------------------------------===//
Max191314 days ago

These verifiers will not work for dynamic shapes. Can you support dynamic cases? The transform is only supported for static shapes right now, but shapes can become dynamic when tiling.

You can create an expected output shape from the input, allowing dynamic dims, and compare with the actual output shape. This helper may be useful:

LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1,

This way will also make it easy to check that the batch/channel dimensions match for the input and output.

Hsiangkai310 days ago

Can we support static shapes in the upstream first? I added a TODO for dynamic cases.

stellaraccident310 days ago

It has a tendency to not get done and not having the shape handling can mask other issues. Usually it's better to do it right, do it once.

Hsiangkai306 days ago

Thanks for your review. I updated the verify() functions to consider dynamic shapes. I also added test cases for dynamic shapes.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
110
111/// Create an empty tensor with alignedType and insert the value into the
112/// created empty tensor with aligned size.
113
static Value insertToAlignedTensor(RewriterBase &rewriter, Location loc,
Max191314 days ago

nit: Rename to padToAlignedTensor

Hsiangkai310 days ago

Done.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
126 auto alignedType = RankedTensorType::get(alignedShape, elementType);
127 Value pad_value = rewriter.create<arith::ConstantOp>(
128 loc, elementType, rewriter.getZeroAttr(elementType));
129
return rewriter.create<tensor::PadOp>(loc, alignedType, value, lowIndices,
Max191314 days ago

nit: Can you use this util?

Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,

Hsiangkai310 days ago

Done.

Hsiangkai Add more tests in Linalg/roundtrip.mlir and Linalg/invalid.mlir
f018ec0b
Hsiangkai Address more comments
11a4ee23
Hsiangkai Hsiangkai requested a review from dcaballe dcaballe 310 days ago
Hsiangkai Hsiangkai requested a review from rengolin rengolin 310 days ago
Hsiangkai Hsiangkai requested a review from Max191 Max191 310 days ago
Hsiangkai Consider dynamic shapes in verify functions
65883dac
Max191
Max191 approved these changes on 2024-07-08
Max191302 days ago

Looks good, thanks! Just make sure to switch to ShapedType::kDynamic before landing.

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
2785 SmallVector<int64_t> expectedOutputShape(6, inputH);
2786 if (ShapedType::isDynamic(inputH)) {
2787 expectedOutputShape[0] = tileSize;
2788
expectedOutputShape[2] = -1;
Max191302 days ago

Instead of -1 can you use ShapedType::kDynamic?

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
27892792 }
2790
2791 if (rightTransform) {
2792 int64_t tileW = (inputW - (r - 1)) / m;
2793 if (inputW != tileW * m + (r - 1))
2794 return emitOpError("input width cannot be tiled in full tile size");
2795 if (tileW != outputTileW)
2796 return emitOpError("number of output width tiles is not correct");
2797 if (outputW != m + r - 1)
2798 return emitOpError("expect output width equals to tile size");
2793 if (ShapedType::isDynamic(inputW)) {
2794 expectedOutputShape[1] = tileSize;
2795
expectedOutputShape[3] = -1;
Max191302 days ago

ditto, use ShapedType::kDynamic

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
28232826
2824 if (!leftTransform && !rightTransform)
2825 return failure();
2826
2827 if (leftTransform) {
2828 if (valueH != m + r - 1)
2827 SmallVector<int64_t> expectedOutputShape(4, valueH);
2828 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
2829
expectedOutputShape[1] = -1;
Max191302 days ago

ShapedType::kDynamic

Conversation is marked as resolved
Show resolved
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
2833 expectedOutputShape[1] = (leftTransform ? m : 1) * valueTileH;
28322834 }
2833
2834 if (rightTransform) {
2835 if (valueW != m + r - 1)
2835 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
2836
expectedOutputShape[2] = -1;
Max191302 days ago

ShapedType::kDynamic

Hsiangkai use ShapedType::kDynamic
97329fa4
Hsiangkai Merge branch 'main' into users/hsiangkai/winograd-ops
4a46b7bf
Hsiangkai Hsiangkai merged 7d246e84 into main 300 days ago
Hsiangkai Hsiangkai deleted the users/hsiangkai/winograd-ops branch 300 days ago

Login to write a write a comment.

Login via GitHub

Assignees
No one assigned
Labels
Milestone