jax
053846ca - [Mosaic GPU] Hand-roll a swizzling pattern to avoid bank conflicts in cross-warp reductions.

Commit
26 days ago
Previously, stores could only target a subset of the banks for a given warp, leading to bank conflicts.

The swizzling pattern depends on the number of banks used for a given store operation (assuming all 4 warps store the data contiguously):

* If exactly 1 bank is hit, then there are no bank conflicts, and there is no need to swizzle;
* If exactly 2 banks are hit, then our stores will span 256 bytes in total, which we can split in 2 rows of 128 bytes (the length of SMEM banks). For warps 0 and 1, stores in the first row (the first 16 lanes) target even-numbered banks, and stores in the second row target odd-numbered banks (and the opposite pattern applies to warps 2 and 3);
* If a multiple of 4 banks are hit, then we simply rotate the warp index by `row_idx % 4`, ensuring that after every 4 rows, each warp has hit all the banks.

Avoiding bank conflicts on stores is beneficial, but pales in comparison to ensuring the subsequent loads are vectorized and without bank conflicts. Take the `1xf32` case. An initial version of this change swizzled load indices according to the pattern above, forcing `LDS.128` instructions to be split up into 4 (conflict-free) `LDS` instructions. It turns out that, although the schedule went from having 4-way bank conflicts on stores to none, splitting up the `LDS.128` instruction unexpectedly increased latency significantly. The likely explanation is increased scheduler (or instruction decode) pressure. To work around this issue, we do not bother swizzling warp indices on loads, which is not necessary so long as we intend to load all the relevant data[^1].

For other data types (e.g. `1xf16`), LLVM was failing to figure out that the loads could be vectorized, and also failed to figure out a conflict-free load pattern.
To remediate that, the present change rewrites the unvectorized load loop to a single explicit vectorized load and a loop of slices, which achieves the desired conflict-free, vectorized load schedule.

## Benchmarks on H100

| Test index | Dtype   | Split            | Vector length | Before | After | Diff  | Diff %  |
|------------|---------|------------------|---------------|--------|-------|-------|---------|
| 0          | float32 | fully_split      | 1             | 2662   | 2469  | -193  | -7.25%  |
| 1          | float32 | fully_split      | 2             | 5968   | 5116  | -852  | -14.28% |
| 2          | float32 | fully_split      | 4             | 27362  | 25820 | -1542 | -5.64%  |
| 3          | float32 | half_warps_major | 1             | 3135   | 2400  | -735  | -23.44% |
| 4          | float32 | half_warps_major | 2             | 5942   | 3466  | -2476 | -41.67% |
| 5          | float32 | half_warps_major | 4             | 23270  | 17767 | -5503 | -23.65% |
| 6          | float32 | half_warps_minor | 1             | 3183   | 2422  | -761  | -23.91% |
| 7          | float32 | half_warps_minor | 2             | 5952   | 3503  | -2449 | -41.15% |
| 8          | float32 | half_warps_minor | 4             | 23176  | 17666 | -5510 | -23.77% |
| 9          | float16 | fully_split      | 1             | 4635   | 2507  | -2128 | -45.91% |
| 10         | float16 | fully_split      | 2             | 2666   | 2460  | -206  | -7.73%  |
| 11         | float16 | fully_split      | 4             | 5950   | 5064  | -886  | -14.89% |
| 12         | float16 | half_warps_major | 1             | 6814   | 6546  | -268  | -3.93%  |
| 13         | float16 | half_warps_major | 2             | 3157   | 2429  | -728  | -23.06% |
| 14         | float16 | half_warps_major | 4             | 5893   | 3488  | -2405 | -40.81% |
| 15         | float16 | half_warps_minor | 1             | 6731   | 6550  | -181  | -2.69%  |
| 16         | float16 | half_warps_minor | 2             | 3158   | 2417  | -741  | -23.46% |
| 17         | float16 | half_warps_minor | 4             | 5883   | 3491  | -2392 | -40.66% |
| 18         | int8    | fully_split      | 1             | 4094   | 4101  | 7     | 0.17%   |
| 19         | int8    | fully_split      | 2             | 11200  | 10284 | -916  | -8.18%  |
| 20         | int8    | fully_split      | 4             | 23625  | 21910 | -1715 | -7.26%  |
| 21         | int8    | half_warps_major | 1             | 4021   | 4023  | 2     | 0.05%   |
| 22         | int8    | half_warps_major | 2             | 10625  | 9489  | -1136 | -10.69% |
| 23         | int8    | half_warps_major | 4             | 31334  | 23175 | -8159 | -26.04% |
| 24         | int8    | half_warps_minor | 1             | 4021   | 3998  | -23   | -0.57%  |
| 25         | int8    | half_warps_minor | 2             | 10443  | 9339  | -1104 | -10.57% |
| 26         | int8    | half_warps_minor | 4             | 25375  | 22434 | -2941 | -11.59% |

## Benchmarks on B200

| Test index | Dtype   | Split            | Vector length | Before | After | Diff  | Diff %  |
|------------|---------|------------------|---------------|--------|-------|-------|---------|
| 0          | float32 | fully_split      | 1             | 1565   | 1562  | -3    | -0.19%  |
| 1          | float32 | fully_split      | 2             | 3576   | 2849  | -727  | -20.33% |
| 2          | float32 | fully_split      | 4             | 14052  | 11831 | -2221 | -15.81% |
| 3          | float32 | half_warps_major | 1             | 1958   | 1647  | -311  | -15.88% |
| 4          | float32 | half_warps_major | 2             | 3500   | 1793  | -1707 | -48.77% |
| 5          | float32 | half_warps_major | 4             | 8848   | 4776  | -4072 | -46.02% |
| 6          | float32 | half_warps_minor | 1             | 1951   | 1635  | -316  | -16.20% |
| 7          | float32 | half_warps_minor | 2             | 3502   | 1793  | -1709 | -48.80% |
| 8          | float32 | half_warps_minor | 4             | 8857   | 4810  | -4047 | -45.69% |
| 9          | float16 | fully_split      | 1             | 4605   | 2607  | -1998 | -43.39% |
| 10         | float16 | fully_split      | 2             | 2331   | 2340  | 9     | 0.39%   |
| 11         | float16 | fully_split      | 4             | 5100   | 4378  | -722  | -14.16% |
| 12         | float16 | half_warps_major | 1             | 6840   | 6592  | -248  | -3.63%  |
| 13         | float16 | half_warps_major | 2             | 2708   | 2386  | -322  | -11.89% |
| 14         | float16 | half_warps_major | 4             | 5064   | 3294  | -1770 | -34.95% |
| 15         | float16 | half_warps_minor | 1             | 6741   | 6615  | -126  | -1.87%  |
| 16         | float16 | half_warps_minor | 2             | 2714   | 2396  | -318  | -11.72% |
| 17         | float16 | half_warps_minor | 4             | 5049   | 3292  | -1757 | -34.80% |
| 18         | int8    | fully_split      | 1             | 4288   | 4273  | -15   | -0.35%  |
| 19         | int8    | fully_split      | 2             | 12471  | 12257 | -214  | -1.72%  |
| 20         | int8    | fully_split      | 4             | 26863  | 27145 | 282   | 1.05%   |
| 21         | int8    | half_warps_major | 1             | 4250   | 4225  | -25   | -0.59%  |
| 22         | int8    | half_warps_major | 2             | 11014  | 11167 | 153   | 1.39%   |
| 23         | int8    | half_warps_major | 4             | 27503  | 27368 | -135  | -0.49%  |
| 24         | int8    | half_warps_minor | 1             | 4209   | 4210  | 1     | 0.02%   |
| 25         | int8    | half_warps_minor | 2             | 10931  | 11033 | 102   | 0.93%   |
| 26         | int8    | half_warps_minor | 4             | 27482  | 27479 | -3    | -0.01%  |

There are a couple of outstanding questions in the benchmark, namely:

1. Why are there configurations where expanding the `vector_length` linearly yields much worse runtime than expected?
2. Why do improvements on `int8` on H100 not translate to similar improvements on B200?

We defer answering these questions to later, since these use cases are less common.

[^1]: This makes the reduction order lane-dependent, but since data distribution and index computations are deterministic, this is not an issue.

PiperOrigin-RevId: 859140307
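
The third swizzling case in the commit message (rotating the warp index by `row_idx % 4`) can be illustrated with a small pure-Python model. This is only a sketch under stated assumptions, not the Mosaic GPU implementation: it assumes the usual NVIDIA shared-memory organisation of 32 banks of 4 bytes each, and it invents a toy layout in which each 128-byte row is divided into four 32-byte slots, one per warp, with each warp storing to 4 rows at once. All names (`store_address`, `conflict_degree`, `warp_conflicts`, the slot layout) are hypothetical and do not appear in the change itself.

```python
NUM_BANKS = 32                       # assumed: 32 SMEM banks
BANK_BYTES = 4                       # assumed: each bank is 4 bytes wide
ROW_BYTES = NUM_BANKS * BANK_BYTES   # 128 bytes per row
NUM_WARPS = 4
LANES = 32
SLOT_BYTES = ROW_BYTES // NUM_WARPS        # toy layout: 32 bytes per warp per row
WORDS_PER_SLOT = SLOT_BYTES // BANK_BYTES  # 8 four-byte words per slot


def store_address(warp_idx, lane_idx, swizzle):
    """Byte address written by one lane in the toy layout."""
    row_idx = lane_idx // WORDS_PER_SLOT   # each warp covers rows 0..3
    word_idx = lane_idx % WORDS_PER_SLOT
    if swizzle:
        # Rotate the warp's slot by the row index, as in the third case.
        slot = (warp_idx + row_idx % NUM_WARPS) % NUM_WARPS
    else:
        slot = warp_idx
    return row_idx * ROW_BYTES + slot * SLOT_BYTES + word_idx * BANK_BYTES


def conflict_degree(addresses):
    """Largest number of distinct 4-byte words that map to the same bank."""
    words_per_bank = {}
    for addr in addresses:
        word = addr // BANK_BYTES
        words_per_bank.setdefault(word % NUM_BANKS, set()).add(word)
    return max(len(words) for words in words_per_bank.values())


def warp_conflicts(swizzle):
    """Conflict degree of a single warp-wide store, for each of the 4 warps."""
    return [
        conflict_degree([store_address(w, lane, swizzle) for lane in range(LANES)])
        for w in range(NUM_WARPS)
    ]


def all_addresses_unique(swizzle):
    """Check that the layout never makes two lanes write the same address."""
    addrs = [
        store_address(w, lane, swizzle)
        for w in range(NUM_WARPS)
        for lane in range(LANES)
    ]
    return len(set(addrs)) == len(addrs)


if __name__ == "__main__":
    print("unswizzled:", warp_conflicts(swizzle=False))
    print("swizzled:  ", warp_conflicts(swizzle=True))
```

In this toy layout, the unswizzled store makes every warp reuse the same 8 banks on each of its 4 rows, while the rotated store spreads each warp over all 32 banks. The rotation is a permutation of the warps within each row, so no two lanes ever collide on an address.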