llvm-project
da944e00 - [mlir][tensor] Add shape inference support for `tensor.concat` op. (#140168)

Commit
216 days ago
[mlir][tensor] Add shape inference support for `tensor.concat` op. (#140168) ## description `tensor.concat` requires operands and the result to match on all dimensions except the concatenation dimension. If one operand is already static in those dimensions, the other operands and result type may safely be refined to that same static shape. This PR adds canonicalization patterns to refine `tensor.concat` types and propagate static shapes to other canonicalization patterns through casts. ## example ```mlir %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->tensor<?x12xi32> ``` becomes: ```mlir %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32> %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32> ``` --------- Co-authored-by: Ian Wood <ianwood2024@u.northwestern.edu>
Author
Parents
Loading