[Mosaic GPU] Hand-roll a swizzling pattern to avoid bank conflicts in cross-warp reductions.
Previously, each warp's stores could only target a subset of the banks,
leading to bank conflicts. The swizzling pattern depends on
the number of banks used for a given store operation (assuming all 4 warps
store the data contiguously):
* If exactly 1 bank is hit, then there are no bank conflicts, and there is no
need to swizzle;
* If exactly 2 banks are hit, then our stores will span 256 bytes in total,
which we can split into 2 rows of 128 bytes (the width of one full row of SMEM
banks). For warps 0 and 1, stores in the first row (the first 16 lanes) target
even-numbered banks, and stores in the second row (the last 16 lanes) target
odd-numbered banks (and the opposite pattern applies to warps 2 and 3);
* If a multiple of 4 banks are hit, then we simply rotate the warp index by
`row_idx % 4`, ensuring that after every 4 rows, each warp has hit all the
banks.
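
As a rough illustration of the three cases above, the following Python model
counts store bank conflicts per warp for 1-, 2-, and 4-byte elements. The
address layout (`(lane * 4 + slot) * elem_bytes`), the 32 banks of 4 bytes, and
the helper names `swizzled_slot` and `store_conflict_degree` are assumptions
made for this sketch; it is not the Mosaic GPU implementation.

```python
NUM_WARPS = 4
WARP_SIZE = 32
NUM_BANKS = 32   # assumed: 32 SMEM banks...
BANK_BYTES = 4   # ...each 4 bytes wide
ROW_BYTES = NUM_BANKS * BANK_BYTES  # 128 bytes per row


def swizzled_slot(warp, lane, elem_bytes):
    """Slot (0..3) within a lane's group of 4 elements that `warp` writes to."""
    lane_bytes = NUM_WARPS * elem_bytes              # bytes owned by one lane
    num_rows = WARP_SIZE * lane_bytes // ROW_BYTES   # 128-byte rows spanned
    row = lane * lane_bytes // ROW_BYTES             # row this lane falls in
    if num_rows == 1:
        return warp                                  # no swizzling needed
    if num_rows == 2:
        return (warp + 2 * row) % NUM_WARPS          # flip even/odd banks
    return (warp + row % 4) % NUM_WARPS              # rotate by row_idx % 4


def store_conflict_degree(warp, elem_bytes, swizzle):
    """Max number of distinct 4-byte words one warp's store sends to a bank."""
    words_per_bank = {}
    for lane in range(WARP_SIZE):
        slot = swizzled_slot(warp, lane, elem_bytes) if swizzle else warp
        addr = (lane * NUM_WARPS + slot) * elem_bytes
        word = addr // BANK_BYTES
        words_per_bank.setdefault(word % NUM_BANKS, set()).add(word)
    return max(len(words) for words in words_per_bank.values())


if __name__ == "__main__":
    for elem_bytes in (1, 2, 4):
        before = [store_conflict_degree(w, elem_bytes, False) for w in range(4)]
        after = [store_conflict_degree(w, elem_bytes, True) for w in range(4)]
        print(elem_bytes, before, after)
```

Under these assumptions, the unswizzled stores are 4-way conflicted for 4-byte
elements and 2-way conflicted for 2-byte elements, while the swizzled stores
are conflict-free in all three cases.
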
Avoiding bank conflicts on stores is beneficial, but pales in comparison to
ensuring the subsequent loads are vectorized and without bank conflicts. Take
the `1xf32` case. An initial version of this change swizzled load indices
according to the pattern above, forcing each `LDS.128` instruction to be split
up into 4 (conflict-free) `LDS` instructions.
It turns out that, although the schedule went from having 4-way bank conflicts
on stores to none, splitting up the `LDS.128` instruction unexpectedly
increased latency significantly. The likely explanation is increased scheduler
(or instruction decode) pressure.
To work around this issue, we do not swizzle warp indices on loads. Swizzling
them is not necessary so long as we intend to load all the
relevant data[^1].
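
To make the point about unswizzled loads concrete, the sketch below models the
`1xf32` case and lists, for a given lane, which warp's value sits in each of
the 4 slots that a single vector load would read. The slot formula and the
names `store_slot` and `reduction_order` are assumptions made for this
illustration.

```python
NUM_WARPS = 4
LANES_PER_ROW = 8  # assumed: 128-byte row / (4 warps * 4 bytes per f32)


def store_slot(warp, lane):
    """Slot written by `warp` for `lane`: warp index rotated by row_idx % 4."""
    row = lane // LANES_PER_ROW
    return (warp + row % 4) % NUM_WARPS


def reduction_order(lane):
    """Warp whose value is found in slots 0..3 when `lane` loads them in order."""
    order = [None] * NUM_WARPS
    for warp in range(NUM_WARPS):
        order[store_slot(warp, lane)] = warp
    return order


if __name__ == "__main__":
    for lane in (0, 8, 16, 24):
        print(lane, reduction_order(lane))
```

Every lane still sees each warp's contribution exactly once; only the order in
which the contributions are accumulated changes from one row to the next.
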
For other data types (e.g. `1xf16`), LLVM was failing to figure out that the
loads could be vectorized, and also failed to figure out a conflict-free load
pattern. To remedy that, the present change rewrites the unvectorized load
loop as a single explicit vectorized load followed by a loop of slices, which
achieves the desired conflict-free, vectorized load schedule.
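
The shape of that rewrite can be sketched in plain Python, using a `bytearray`
as a stand-in for SMEM and `struct` for the `f16` accesses. This is only an
analogy for the generated code: the buffer layout and the function names
`reduce_scalar_loads` and `reduce_vector_load` are assumptions, and the real
change operates on MLIR/LLVM vector operations rather than Python objects.

```python
import struct

NUM_WARPS = 4
WARP_SIZE = 32
F16_BYTES = 2


def make_smem():
    """Fill a fake SMEM buffer: slot `s` of lane `l` holds float16(l + s)."""
    smem = bytearray(WARP_SIZE * NUM_WARPS * F16_BYTES)
    for lane in range(WARP_SIZE):
        for slot in range(NUM_WARPS):
            offset = (lane * NUM_WARPS + slot) * F16_BYTES
            struct.pack_into("<e", smem, offset, float(lane + slot))
    return smem


def reduce_scalar_loads(smem, lane):
    """Before: one scalar f16 load per slot inside the loop."""
    acc = 0.0
    for slot in range(NUM_WARPS):
        offset = (lane * NUM_WARPS + slot) * F16_BYTES
        (value,) = struct.unpack_from("<e", smem, offset)
        acc += value
    return acc


def reduce_vector_load(smem, lane):
    """After: a single 4 x f16 load, then a loop over slices of the vector."""
    vec = struct.unpack_from(f"<{NUM_WARPS}e", smem, lane * NUM_WARPS * F16_BYTES)
    acc = 0.0
    for slot in range(NUM_WARPS):
        acc += vec[slot]
    return acc


if __name__ == "__main__":
    smem = make_smem()
    print(reduce_scalar_loads(smem, 5), reduce_vector_load(smem, 5))
```

Both functions compute the same sum; the difference is that the second issues
one memory access per lane instead of four.
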
## Benchmarks on H100
| Test index | Dtype | Split | Vector length | Before | After | Diff | Diff % |
|------------|---------|------------------|---------------|--------|-------|-------|---------|
| 0 | float32 | fully_split | 1 | 2662 | 2469 | -193 | -7.25% |
| 1 | float32 | fully_split | 2 | 5968 | 5116 | -852 | -14.28% |
| 2 | float32 | fully_split | 4 | 27362 | 25820 | -1542 | -5.64% |
| 3 | float32 | half_warps_major | 1 | 3135 | 2400 | -735 | -23.44% |
| 4 | float32 | half_warps_major | 2 | 5942 | 3466 | -2476 | -41.67% |
| 5 | float32 | half_warps_major | 4 | 23270 | 17767 | -5503 | -23.65% |
| 6 | float32 | half_warps_minor | 1 | 3183 | 2422 | -761 | -23.91% |
| 7 | float32 | half_warps_minor | 2 | 5952 | 3503 | -2449 | -41.15% |
| 8 | float32 | half_warps_minor | 4 | 23176 | 17666 | -5510 | -23.77% |
| 9 | float16 | fully_split | 1 | 4635 | 2507 | -2128 | -45.91% |
| 10 | float16 | fully_split | 2 | 2666 | 2460 | -206 | -7.73% |
| 11 | float16 | fully_split | 4 | 5950 | 5064 | -886 | -14.89% |
| 12 | float16 | half_warps_major | 1 | 6814 | 6546 | -268 | -3.93% |
| 13 | float16 | half_warps_major | 2 | 3157 | 2429 | -728 | -23.06% |
| 14 | float16 | half_warps_major | 4 | 5893 | 3488 | -2405 | -40.81% |
| 15 | float16 | half_warps_minor | 1 | 6731 | 6550 | -181 | -2.69% |
| 16 | float16 | half_warps_minor | 2 | 3158 | 2417 | -741 | -23.46% |
| 17 | float16 | half_warps_minor | 4 | 5883 | 3491 | -2392 | -40.66% |
| 18 | int8 | fully_split | 1 | 4094 | 4101 | 7 | 0.17% |
| 19 | int8 | fully_split | 2 | 11200 | 10284 | -916 | -8.18% |
| 20 | int8 | fully_split | 4 | 23625 | 21910 | -1715 | -7.26% |
| 21 | int8 | half_warps_major | 1 | 4021 | 4023 | 2 | 0.05% |
| 22 | int8 | half_warps_major | 2 | 10625 | 9489 | -1136 | -10.69% |
| 23 | int8 | half_warps_major | 4 | 31334 | 23175 | -8159 | -26.04% |
| 24 | int8 | half_warps_minor | 1 | 4021 | 3998 | -23 | -0.57% |
| 25 | int8 | half_warps_minor | 2 | 10443 | 9339 | -1104 | -10.57% |
| 26 | int8 | half_warps_minor | 4 | 25375 | 22434 | -2941 | -11.59% |
## Benchmarks on B200
| Test index | Dtype | Split | Vector length | Before | After | Diff | Diff % |
|------------|---------|------------------|---------------|--------|-------|-------|---------|
| 0 | float32 | fully_split | 1 | 1565 | 1562 | -3 | -0.19% |
| 1 | float32 | fully_split | 2 | 3576 | 2849 | -727 | -20.33% |
| 2 | float32 | fully_split | 4 | 14052 | 11831 | -2221 | -15.81% |
| 3 | float32 | half_warps_major | 1 | 1958 | 1647 | -311 | -15.88% |
| 4 | float32 | half_warps_major | 2 | 3500 | 1793 | -1707 | -48.77% |
| 5 | float32 | half_warps_major | 4 | 8848 | 4776 | -4072 | -46.02% |
| 6 | float32 | half_warps_minor | 1 | 1951 | 1635 | -316 | -16.20% |
| 7 | float32 | half_warps_minor | 2 | 3502 | 1793 | -1709 | -48.80% |
| 8 | float32 | half_warps_minor | 4 | 8857 | 4810 | -4047 | -45.69% |
| 9 | float16 | fully_split | 1 | 4605 | 2607 | -1998 | -43.39% |
| 10 | float16 | fully_split | 2 | 2331 | 2340 | 9 | 0.39% |
| 11 | float16 | fully_split | 4 | 5100 | 4378 | -722 | -14.16% |
| 12 | float16 | half_warps_major | 1 | 6840 | 6592 | -248 | -3.63% |
| 13 | float16 | half_warps_major | 2 | 2708 | 2386 | -322 | -11.89% |
| 14 | float16 | half_warps_major | 4 | 5064 | 3294 | -1770 | -34.95% |
| 15 | float16 | half_warps_minor | 1 | 6741 | 6615 | -126 | -1.87% |
| 16 | float16 | half_warps_minor | 2 | 2714 | 2396 | -318 | -11.72% |
| 17 | float16 | half_warps_minor | 4 | 5049 | 3292 | -1757 | -34.80% |
| 18 | int8 | fully_split | 1 | 4288 | 4273 | -15 | -0.35% |
| 19 | int8 | fully_split | 2 | 12471 | 12257 | -214 | -1.72% |
| 20 | int8 | fully_split | 4 | 26863 | 27145 | 282 | 1.05% |
| 21 | int8 | half_warps_major | 1 | 4250 | 4225 | -25 | -0.59% |
| 22 | int8 | half_warps_major | 2 | 11014 | 11167 | 153 | 1.39% |
| 23 | int8 | half_warps_major | 4 | 27503 | 27368 | -135 | -0.49% |
| 24 | int8 | half_warps_minor | 1 | 4209 | 4210 | 1 | 0.02% |
| 25 | int8 | half_warps_minor | 2 | 10931 | 11033 | 102 | 0.93% |
| 26 | int8 | half_warps_minor | 4 | 27482 | 27479 | -3 | -0.01% |
A couple of questions about the benchmarks remain open, namely:
1. Why do some configurations scale much worse than linearly in runtime as
`vector_length` increases?
2. Why do improvements on `int8` on H100 not translate to similar improvements
on B200?
We defer answering these questions, since these use cases are less
common.
[^1]: This makes the reduction order lane-dependent, but since data
distribution and index computations are deterministic, this is not an issue.
PiperOrigin-RevId: 859140307