llvm-project
c41286af - [mlir][linalg] Emit a warning when tile_using_forall generates non thread-safe code (#80813)

Commit
1 year ago
**Description**

The documentation of `transform.structured.tile_using_forall` says: _"It is the user’s responsibility to ensure that num_threads/tile_sizes is a valid tiling specification (i.e. that only tiles parallel dimensions, e.g. in the Linalg case)."_

In other words, tiling a non-parallel dimension generates code with data races, which is not safe to parallelize.

For example, consider this code, produced by tiling a test case included in this PR:

```
func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
  %0 = scf.forall (%arg2) in (8) shared_outs(%arg3 = %arg1) -> (tensor<300x8xf32>) {
    %1 = affine.min #map(%arg2)
    %2 = affine.max #map1(%1)
    %3 = affine.apply #map2(%arg2)
    %extracted_slice = tensor.extract_slice %arg0[%3, 0, 0] [%2, 300, 8] [1, 1, 1] : tensor<100x300x8xf32> to tensor<?x300x8xf32>
    %4 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["reduction", "parallel", "parallel"]} ins(%extracted_slice : tensor<?x300x8xf32>) outs(%arg3 : tensor<300x8xf32>) {
    ^bb0(%in: f32, %out: f32):
      %5 = arith.addf %in, %out : f32
      linalg.yield %5 : f32
    } -> tensor<300x8xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %4 into %arg3[0, 0] [300, 8] [1, 1] : tensor<300x8xf32> into tensor<300x8xf32>
    }
  }
  return %0 : tensor<300x8xf32>
}
```

This is not safe to parallelize, because all threads write to the same position in `%arg3` (in the `scf.forall.in_parallel` region). This PR detects whether it is safe to apply `tile_using_forall` and emits a warning when it is not.

**Brief explanation**

The check first generates a vector of affine expressions representing the tile values and stores it in `dimExprs`. These affine expressions are then compared with the affine expressions in the results of the affine map of each output of the linalg op.

Going back to the previous example, the original input and transform are:

```
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
  // expected-warning@+1 {{tiling is not thread safe at axis #0}}
  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.addf %in, %out : f32
    linalg.yield %1 : f32
  } -> tensor<300x8xf32>
  return %0 : tensor<300x8xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
```

The `num_threads` attribute is represented as `(d0)`. Because the linalg op has only one output (`%arg1`), the check only compares against the results of `#map1`, which are `(d1, d2)`. The idea is to verify that every affine expression in `dimExprs` is present in the output affine map. In this example, `d0` is not in `(d1, d2)`, so tiling that axis is considered not thread safe.
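
The comparison described above can be illustrated with a small Python model. This is only a sketch of the idea, not the C++ code added by this PR: dimension expressions are modelled as plain strings such as `"d0"`, and the helper names `dim_exprs_for_tiling` and `unsafe_axes` are hypothetical. It assumes that entry `i` of `num_threads` tiles dimension `d<i>`, as in the example where `num_threads [8]` is represented as `(d0)`.

```python
def dim_exprs_for_tiling(num_threads):
    """Model the tiled axes as dimension names: entry i of num_threads -> 'd<i>'.

    Hypothetical helper; stands in for the `dimExprs` vector in the description.
    """
    return [f"d{i}" for i in range(len(num_threads))]


def unsafe_axes(num_threads, output_map_results):
    """Return the tiled axes whose dim expression is missing from an output map.

    `output_map_results` holds, for each output of the op, the results of its
    indexing map, e.g. ("d1", "d2") for affine_map<(d0, d1, d2) -> (d1, d2)>.
    """
    unsafe = []
    for axis, expr in enumerate(dim_exprs_for_tiling(num_threads)):
        # An axis is flagged if at least one output is not indexed by it.
        if any(expr not in results for results in output_map_results):
            unsafe.append(axis)
    return unsafe


# Example from the commit message: num_threads [8], one output indexed by (d1, d2).
for axis in unsafe_axes([8], [("d1", "d2")]):
    print(f"warning: tiling is not thread safe at axis #{axis}")
```

With the values from the example, the model flags axis 0, matching the `expected-warning` in the test. If the output were indexed by `(d0, d1, d2)` instead, no axis would be flagged.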