llvm-project
e2870454 - [MLIR][Shard] Fix three bugs in ND mesh resharding in Partition pass (#189241)

Commit
10 days ago
[MLIR][Shard] Fix three bugs in ND mesh resharding in Partition pass (#189241) A new MoveLastSplitAxisPattern class handles the case where the last grid axis of one tensor dimension is moved to the front of another tensor dimension's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]]. The three bugs fixed are: 1. detectMoveLastSplitAxisInResharding: compared source.back() with target.back() instead of target.front(), preventing the pattern from being detected for resharding like [[0,1],[2]] -> [[0],[1,2]]. 2. targetShardingInMoveLastAxis: axes were appended with push_back but should be inserted at the front, producing wrong split_axes order. 3. handlePartialAxesDuringResharding: a copy_if wrote results into the wrong output variable (addressed structurally by the clean implementation). Fixes #136117 Assisted-by: Claude Code
Author
Parents
Loading