pytorch
5bcbad76 - Fix perf regression introduced in #70943 (#78588)

Commit
2 years ago
Fix perf regression introduced in #70943 (#78588) `numel` is a too loose upper bound for `num_of_segments` and `num_of_partial_segments`. It causes perf regressions. This PR moves to a tighter upper bound. Benchmark with jupyter notebook: ```python import torch num_embeddings = 1024 embedding_dim = 512 e = torch.nn.Embedding(num_embeddings, embedding_dim).cuda() size = 1*1024*1024 i = torch.arange(size, device='cuda') % num_embeddings o = e(i) g = torch.randn_like(o) torch.cuda.synchronize() ``` ```python %%timeit o.backward(g, retain_graph=True) torch.cuda.synchronize() ``` Before #70943: 3.6 ms After #70943: 6.9 ms With this PR: 3.55 ms Pull Request resolved: https://github.com/pytorch/pytorch/pull/78588 Approved by: https://github.com/ngimel
Author
Committer
Parents
Loading