[CUDA] Transpose3DImpl Supporting More Cases (#13611)
CUDA's Transpose3DImpl transposes a [batch, m, n] tensor to [batch, n, m].
Currently it requires that both m and n be divisible by 32 or 16. When that is
not the case, the computation falls back to the general implementation, which
is slow. This PR removes that limitation.
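
A minimal sketch of the idea, not the actual ORT kernel: a shared-memory tiled
transpose of [batch, m, n] to [batch, n, m], where bounds checks on the loads
and stores (the extra IF statements mentioned below) let the kernel handle m
and n that are not multiples of the tile size. `BatchedTiledTranspose` and
`TILE_DIM` are illustrative names, not taken from the ONNX Runtime source.

```cuda
#define TILE_DIM 32

// Sketch only: grid = (ceil(n/TILE_DIM), ceil(m/TILE_DIM), batch),
// block = (TILE_DIM, TILE_DIM, 1).
template <typename T>
__global__ void BatchedTiledTranspose(const T* input, T* output, int m, int n) {
  __shared__ T tile[TILE_DIM][TILE_DIM + 1];  // +1 pad avoids shared-memory bank conflicts

  const T* in = input + (size_t)blockIdx.z * m * n;  // select this batch slice
  T* out = output + (size_t)blockIdx.z * m * n;

  int x = blockIdx.x * TILE_DIM + threadIdx.x;  // column index in the [m, n] input
  int y = blockIdx.y * TILE_DIM + threadIdx.y;  // row index in the [m, n] input

  // Guarded load: only threads inside the [m, n] extent touch global memory,
  // so tiles hanging over the edge are handled correctly.
  if (x < n && y < m) {
    tile[threadIdx.y][threadIdx.x] = in[(size_t)y * n + x];
  }
  __syncthreads();

  // Recompute coordinates for the transposed [n, m] output tile.
  x = blockIdx.y * TILE_DIM + threadIdx.x;  // column index in the [n, m] output
  y = blockIdx.x * TILE_DIM + threadIdx.y;  // row index in the [n, m] output

  // Guarded store: the check lets n and m be arbitrary, not multiples of TILE_DIM.
  if (x < m && y < n) {
    out[(size_t)y * m + x] = tile[threadIdx.x][threadIdx.y];
  }
}
```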
Profiling on a V100 with the tensor sizes below, the following cycle counts
were collected from Nsight Compute:
Shape [batch, m, n] | Old (cycles) | New (cycles)
-- | -- | --
[3072,64,512] | 760793 | 727140
[3072,16,2048] | 854303 | 851146
[3072,2048,12] | 986924 | 737884
[3072,1024,24] | 1212427 | 495117
It shows that even though extra IF statements were added to the kernel
implementation, there is nearly no impact on the cases the old version already
handled (cases 1 and 2). And for cases 3 and 4, which previously fell back to
the general implementation, the new kernel is much faster.
The data above was collected using FP16 tensors; similar results were observed
for float tensors.
This PR improves the performance of ORT training for Hugging Face's XLNet
model, which contains a [8,1024,1024,12].permute(0,3,1,2) operation.
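
For illustration, a hypothetical launch for this XLNet case, building on the
sketch kernel above: permute(0,3,1,2) on a [8, 1024, 1024, 12] tensor is
equivalent to a 3D transpose with batch = 8, m = 1024 * 1024, n = 12, and since
n = 12 is not a multiple of 16 or 32 this is exactly the shape that used to hit
the general fallback. `LaunchXlnetPermute`, `d_in`, and `d_out` are assumed
names, not ORT code; the buffers are assumed to hold 8 * 1024 * 1024 * 12 half
elements on the device.

```cuda
#include <cuda_fp16.h>

void LaunchXlnetPermute(const half* d_in, half* d_out) {
  const int batch = 8, m = 1024 * 1024, n = 12;
  dim3 block(TILE_DIM, TILE_DIM, 1);
  dim3 grid((n + TILE_DIM - 1) / TILE_DIM,   // ceil-divide so partial tiles are covered
            (m + TILE_DIM - 1) / TILE_DIM,
            batch);
  BatchedTiledTranspose<half><<<grid, block>>>(d_in, d_out, m, n);
}
```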