llvm-project
077a796c - [mlir] Implement a memory-space cast bubbling-down transform (#159454)

Commit
214 days ago
[mlir] Implement a memory-space cast bubbling-down transform (#159454) This commit adds functionality to bubble down memory-space casts operations, allowing consumer operations to use the original memory-space rather than first casting to a different memory space. Changes: - Introduce `MemorySpaceCastOpInterface` to handle memory-space cast operations - Create a `MemorySpaceCastConsumerOpInterface` pass that identifies and bubbles down eligible casts - Add implementation for memref and vector operations to handle memory-space cast propagation - Add `bubbleDownCasts` method to relevant operations to support the fusion In particular, in the current implementation only memory-space casts into the default memory-space can be bubbled-down. Example: ```mlir func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> { %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32> %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32> %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32> %loaded = memref.load %collapsed[%c0] : memref<16xf32> %added = arith.addf %loaded, %arg2 : f32 memref.store %added, %collapsed[%c0] : memref<16xf32> %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32 return %collapsed : memref<16xf32> } // mlir-opt --bubble-down-memory-space-casts func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> { %c4 = arith.constant 4 : index %c0 = arith.constant 0 : index %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1> %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1> %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32> %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1> %1 = arith.addf %0, %arg2 : f32 memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1> %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32 return %memspacecast : memref<16xf32> } ``` --------- Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com> Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
Author
Parents
Loading