llvm-project
e22508ea - [mlir][vector] Update `CombineContractBroadcastMask` (#140050)

Commit
128 days ago
[mlir][vector] Update `CombineContractBroadcastMask` (#140050) This patch updates `CombineContractBroadcastMask` to inherit from `MaskableOpRewritePattern`, enabling it to handle masked `vector.contract` operations. The pattern rewrites: ```mlir %a = vector.broadcast %a_bc %res vector.contract %a_bc, %b, ... ``` into: ```mlir // Move the broadcast into vector.contract (by updating the indexing // maps) %res vector.contract %a, %b, ... ``` The main challenge is supporting cases where the pattern drops a leading unit dimension. For example: ```mlir func.func @contract_broadcast_unit_dim_reduction_masked( %arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>, %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> { %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32> %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32> %result = vector.mask %mask { vector.contract { indexing_maps = [#map0, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add> } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32> } : vector<1x8x8x4xi1> -> vector<8x8xi32> return %result : vector<8x8xi32> } ``` Here, the leading unit dimension is dropped. To handle this, the mask is cast to the correct shape using a `vector.shape_cast`: ```mlir func.func @contract_broadcast_unit_dim_reduction_masked( %arg0: vector<8x4xi32>, %arg1: vector<8x4xi32>, %arg2: vector<8x8xi32>, %arg3: vector<1x8x8x4xi1>) -> vector<8x8xi32> { %mask_sc = vector.shape_cast %arg3 : vector<1x8x8x4xi1> to vector<8x8x4xi1> %res = vector.mask %mask_sc { vector.contract { indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add> } %arg0, %arg1, %mask_sc : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32> } : vector<8x8x4xi1> -> vector<8x8xi32> return %res : vector<8x8xi32> } ``` While this isn't ideal - since it introduces a `vector.shape_cast` that must be cleaned up later - it reflects the best we can do once the input reaches `CombineContractBroadcastMask`. A more robust solution may involve simplifying the input earlier. I am leaving that as a TODO for myself to explore this further. Posting this now to unblock downstream work. LIMITATIONS Currently, this pattern assumes: * Only leading dimensions are dropped in the mask. * All dropped dimensions must be unit-sized.
Author
Parents
Loading