llvm-project
bce951c5 - [mlir][linalg] Update vectorization logic for linalg.unpack (#149156)

Commit
156 days ago
[mlir][linalg] Update vectorization logic for linalg.unpack (#149156) This PR makes sure that we don't generate unnecessary `tensor.empty` when vectorizing `linalg.unpack`. To better visualize the changes implemented here, consider this IR: ```mlir func.func @example( %source: tensor<8x4x16x16xf32>, %dest: tensor<64x127xf32>) -> tensor<64x127xf32> { %res = linalg.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %dest : tensor<8x4x16x16xf32> -> tensor<64x127xf32> return %res : tensor<64x127xf32> } ``` Below is the output after vectorization, BEFORE and AFTER this PR. BEFORE (note `tensor.empty` and the fact that `%arg1` is not used): ```mlir func.func @example(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<64x127xf32>) -> tensor<64x127xf32> { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32> %1 = vector.transpose %0, [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32> %2 = vector.shape_cast %1 : vector<4x16x8x16xf32> to vector<64x128xf32> %3 = tensor.empty() : tensor<64x127xf32> %c0_0 = arith.constant 0 : index %4 = vector.transfer_write %2, %3[%c0_0, %c0_0] {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32> return %4 : tensor<64x127xf32> } ``` AFTER (note that `%arg1` is correctly used): ```mlir func.func @example(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<64x127xf32>) -> tensor<64x127xf32> { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32> %1 = vector.transpose %0, [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32> %2 = vector.shape_cast %1 : vector<4x16x8x16xf32> to vector<64x128xf32> %c0_0 = arith.constant 0 : index %3 = vector.transfer_write %2, %arg1[%c0_0, %c0_0] {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32> return %3 : tensor<64x127xf32> } ```
Author
Parents
Loading