llvm-project
6d9ce676 - [mlir][bufferization] implement BufferizableOpInterface for concat op (#140171)

Commit
174 days ago
[mlir][bufferization] implement BufferizableOpInterface for concat op (#140171) Lowers `tensor.concat` to an alloc with a series of `memref.copy` ops to copy the operands to the alloc. Example: ```mlir func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> { %t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> return %t : tensor<16xf32> } ``` Produces ```mlir module { func.func @tensor.concat(%arg0: tensor<8xf32>) -> tensor<16xf32> { // initialization %0 = bufferization.to_memref %arg0 : tensor<8xf32> to memref<8xf32> %alloc = memref.alloc() {alignment = 64 : i64} : memref<8xf32> memref.copy %0, %alloc : memref<8xf32> to memref<8xf32> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8xf32> memref.copy %0, %alloc_0 : memref<8xf32> to memref<8xf32> %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<16xf32> // one copy for each operand %subview = memref.subview %alloc_1[0] [8] [1] : memref<16xf32> to memref<8xf32, strided<[1]>> memref.copy %alloc, %subview : memref<8xf32> to memref<8xf32, strided<[1]>> %subview_2 = memref.subview %alloc_1[8] [8] [1] : memref<16xf32> to memref<8xf32, strided<[1], offset: 8>> memref.copy %alloc_0, %subview_2 : memref<8xf32> to memref<8xf32, strided<[1], offset: 8>> %1 = bufferization.to_tensor %alloc_1 : memref<16xf32> to tensor<16xf32> return %1 : tensor<16xf32> } } ``` This is my first time implementing BufferizableOpInterface, so I'm looking for some advice on how I can: 1. Clean up my implementation. 2. Avoid duplicate `memref.copy` ops in the `// initialization` section above when handling duplicate `tensor.concat` operands. --------- Co-authored-by: Jeremy Kun <j2kun@users.noreply.github.com>
Author
Parents
Loading