enable bf16 for cat serial kernel (#54674)
Summary:
Timings for cat of 10 2-D tensors at dim=1 (num_threads=1):
| dtype | shape | serial kernel | copy kernel |
| ----- | ----- | ------------- | ----------- |
| fp32 | 1024 * 16k | 105.45 ms | 102.41 ms |
| fp32 | 1024 * (100 + i) | 324.75 us | 448.66 us |
| bf16 | 1024 * 16k | 49.82 ms | 51.39 ms |
| bf16 | 1024 * (100 + i) | 164.74 us | 244.64 us |
where i = 0, ..., 9 is the index of each of the 10 tensors.
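
To make the table easier to read, the copy/serial ratios can be worked out directly from the reported timings. The following is a small sketch, not part of the PR; the names `timings` and `copy_over_serial` are only illustrative, and the numbers are copied from the table above.

```python
# Timings copied from the table above: (serial kernel, copy kernel).
# Large shapes are in ms, small shapes in us; units cancel in each ratio.
timings = {
    ("fp32", "1024 * 16k"): (105.45, 102.41),
    ("fp32", "1024 * (100 + i)"): (324.75, 448.66),
    ("bf16", "1024 * 16k"): (49.82, 51.39),
    ("bf16", "1024 * (100 + i)"): (164.74, 244.64),
}

# Ratio > 1 means the serial kernel is faster than the copy kernel.
copy_over_serial = {
    key: copy / serial for key, (serial, copy) in timings.items()
}

for (dtype, shape), ratio in copy_over_serial.items():
    print(f"{dtype} {shape}: copy / serial = {ratio:.2f}")
```

On these numbers the serial kernel is roughly 1.4x to 1.5x faster than the copy kernel for the small tensors, while the two kernels are within a few percent of each other for the large ones.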
Benchmark code:
```
import torch
import torch.utils.benchmark as benchmark

def cat(*args, dim=0):
    return torch.cat(args, dim)

tensors = []
for i in range(10):
    # Leave exactly one of these lines uncommented to select the benchmarked case.
    tensors.append(torch.rand(1024, 16 * 1024))
    # tensors.append(torch.rand(1024, 16 * 1024).bfloat16())
    # tensors.append(torch.rand(1024, 100 + i))
    # tensors.append(torch.rand(1024, 100 + i).bfloat16())

t0 = benchmark.Timer(
    stmt='cat(*tensors, dim=1)',
    setup='from __main__ import cat',
    globals={'tensors': tensors},
    num_threads=1)
print(t0.blocked_autorange())
```
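
Alongside the timing, a quick correctness check can confirm that cat on bf16 inputs gives the same result as cat in fp32 followed by a cast. This is a sketch, not part of the PR's tests; the variable names are illustrative, and it assumes a CPU build of PyTorch with bfloat16 support.

```python
import torch

torch.manual_seed(0)

# Same small, uneven shapes as the benchmark: 1024 x (100 + i), i = 0..9.
fp32_tensors = [torch.rand(1024, 100 + i) for i in range(10)]
bf16_tensors = [t.bfloat16() for t in fp32_tensors]

# Path under test: cat directly on bf16 inputs.
out_bf16 = torch.cat(bf16_tensors, dim=1)

# Reference: cat in fp32, then cast. cat only copies elements, so the
# per-element cast commutes with it and the results should match exactly.
expected = torch.cat(fp32_tensors, dim=1).bfloat16()

assert out_bf16.dtype == torch.bfloat16
assert out_bf16.shape == (1024, sum(100 + i for i in range(10)))
assert torch.equal(out_bf16, expected)
```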
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54674
Reviewed By: ailzhang
Differential Revision: D27325347
Pulled By: heitorschueroff
fbshipit-source-id: 7a0f4bf8d92dbf8e725fdd2e8a2c901188811d6f