pytorch
4a99b4f1 - enable Half for cat serial kernel (#96021)

Commit
1 year ago
enable Half for cat serial kernel (#96021) Summary: 1.31 x speedup. | | shape | before | after | | ------------ | ------------- | ------------ | ------------- | | half | 1024 * (100 + i) | 235.75 us | 179.11 us | Benchmark with ``` import torch import torch.utils.benchmark as benchmark def cat(*args, dim=0): return torch.cat(args, dim) tensors = [] for i in range(10): tensors.append(torch.rand(1024, 100 + i).half()) t0 = benchmark.Timer( stmt="cat(*tensors, dim=1)", setup="from __main__ import cat", globals={"tensors": tensors}, num_threads=1, ) ``` Test Plan: CI Differential Revision: D43810514 Pull Request resolved: https://github.com/pytorch/pytorch/pull/96021 Approved by: https://github.com/ngimel, https://github.com/houseroad, https://github.com/jgong5
Author
Committer
Parents
Loading