Dynamically assign number of threads in innerdim scan (#103435)
This continues the work from #103314 on optimizing inner-dimension scan operations (`torch.cumsum`, `torch.cumprod`, `torch.logcumsumexp`) by dynamically setting the number of threads based on the input shape.
What I found is that simply setting the number of x-threads and y-threads according to the ratio of the tensor's dimensions works quite well (with some clamping).
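To make the idea concrete, here is a hypothetical Python sketch of such a heuristic (this is not the actual CUDA launch code in this PR; the function name, clamp bounds, and thread budget are illustrative assumptions). It gives more x-threads to long inner dimensions and spends the remaining per-block budget on y-threads across rows:

```python
# Illustrative sketch only: names and bounds (512 threads/block,
# minimum of 16 x-threads) are assumptions, not values from the PR.

def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (CUDA block dims are typically powers of two)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def pick_block_dims(row_size: int, num_rows: int,
                    max_threads: int = 512,
                    min_threads_x: int = 16) -> tuple[int, int]:
    """Return (threads_x, threads_y) with threads_x * threads_y <= max_threads."""
    # More x-threads for longer scanned rows, clamped to a sane range.
    threads_x = max(min(next_power_of_2(row_size), max_threads), min_threads_x)
    # Remaining budget covers multiple rows per block, capped by the batch size.
    threads_y = min(max_threads // threads_x, next_power_of_2(num_rows))
    return threads_x, threads_y
```

For example, a `(2, 1048576)` input would get a wide block (many x-threads, one row per block), while a `(4096, 8)` input would get a narrow, tall block.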
Here is the speed-up of this PR compared to `2.0.0+cu118` (not compared to #103314), measured on an A100 with 40 GB of memory (up to 23x faster):
```
2 8 32 128 512 1024 2048 4096 8096 16348 65536 262144 1048576
2: 1.07(4) 1.02(5) 1.01(6) 1.07(7) 2.16(8) 4.94(9) 8.71(9) 11.00(9) 12.99(9) 14.77(9) 16.41(9) 16.81(9) 16.97(9)
8: 1.20(4) 1.00(4) 1.01(5) 1.08(6) 2.85(7) 4.90(8) 6.34(8) 11.76(9) 13.86(9) 15.26(9) 16.96(9) 17.45(9) 19.75(9)
32: 1.08(4) 1.00(4) 1.00(4) 1.23(5) 2.48(6) 4.23(7) 5.04(7) 9.16(8) 10.11(8) 18.72(9) 20.64(9) 23.13(9) 23.50(9)
128: 1.09(4) 1.02(4) 1.03(4) 1.02(4) 1.64(5) 2.84(6) 3.08(6) 5.61(7) 5.86(7) 10.72(8) 19.22(9) 19.75(9) 19.97(9)
512: 1.06(4) 1.14(4) 1.01(4) 1.10(4) 1.02(4) 1.78(5) 1.85(5) 3.26(6) 3.34(6) 5.56(7) 8.56(8) 9.55(9) 9.62(9)
1024: 1.21(4) 1.22(4) 1.20(4) 1.06(4) 1.03(4) 1.05(4) 1.81(5) 1.86(5) 3.06(6) 3.12(6) 4.76(7) 5.20(8) 5.56(9)
2048: 1.04(4) 0.88(4) 1.00(4) 1.01(4) 1.02(4) 1.03(4) 1.02(4) 1.72(5) 1.73(5) 2.62(6) 2.86(7) 3.06(8) --------
4096: 1.02(4) 1.12(4) 0.98(4) 1.60(4) 1.16(4) 1.09(4) 1.10(4) 1.10(4) 1.74(5) 1.75(5) 1.86(6) 2.00(7) --------
8096: 1.03(4) 1.00(4) 1.00(4) 1.16(4) 1.17(4) 1.17(4) 1.18(4) 1.18(4) 1.18(4) 1.27(5) 1.43(6) -------- --------
16348: 1.02(4) 1.15(4) 1.11(4) 1.17(4) 1.12(4) 1.11(4) 1.13(4) 1.12(4) 1.11(4) 1.08(4) 1.32(5) -------- --------
65536: 1.17(4) 1.17(4) 1.16(4) 1.15(4) 1.12(4) 1.12(4) 1.12(4) 1.10(4) 1.10(4) 1.07(4) -------- -------- --------
262144: 1.20(4) 1.20(4) 1.08(4) 1.13(4) 1.10(4) 1.09(4) 1.10(4) 1.08(4) -------- -------- -------- -------- --------
1048576: 1.21(4) 1.14(4) 1.10(4) 1.13(4) 1.09(4) 1.08(4) -------- -------- -------- -------- -------- -------- --------
```
The first row is the innermost dimension, the first column is the outermost dimension (i.e. the batch size).
The floating-point numbers are the speed-ups, while the integers in brackets are the log2 of the number of x-threads.
The cells with dashes were not benchmarked because of my GPU's memory limitation.
There are some slowdowns (e.g. `(2048, 8)` and `(4096, 32)`). They occur because in this PR the scan loop (the Sklansky loop) has a dynamic number of iterations, `log2(num_threads_x)`, so the compiler cannot unroll it, whereas the previous version had a fixed iteration count and the loop could be unrolled and optimized.
To compensate, I slightly modified the operations within the scan loop to use bit operations.
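For illustration, the Sklansky scan pattern with bit-operation index math can be sketched in Python as follows (a serial sketch of the parallel pattern, not the actual CUDA kernel; in the kernel each `i` is a thread lane and the loop over `d` runs `log2(num_threads_x)` times):

```python
def sklansky_scan(a, op):
    """In-place inclusive Sklansky scan; len(a) must be a power of two.

    Serial sketch of the parallel pattern: at each level, lanes whose
    bit `d` is set combine with the last element of the preceding block,
    with the partner index computed purely via bit operations.
    """
    n = len(a)
    d = 1
    while d < n:  # log2(n) levels
        for i in range(n):
            if i & d:
                # Last index of the previous block: set the low bits, clear bit d.
                partner = (i | (d - 1)) ^ d
                a[i] = op(a[partner], a[i])
        d <<= 1
    return a
```

Note that within a level every partner has bit `d` clear, so partners are never modified at that level and the in-place update is safe.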
The most significant acceleration comes from tensors with a relatively small batch size (<= 4096) and a very long inner dimension.
As the batch size increases, the speed-up becomes less significant because the previous implementation was already likely to be well optimized.
NOTE: I haven't optimized the scan variants that return indices; that could come in another PR.
As for the build time, I tried not to write more templated functions than necessary.
UPDATE: I compared build times when changing only ScanUtils.cuh: 4m2s on the `main` branch versus 3m39s in this PR.
What do you think, @ngimel?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103435
Approved by: https://github.com/ngimel