torch.multinomial : fast-path for replacement=False (#39636)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/39624 and #11931
Based on the example by RobertoLat
https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
**The fast path is not taken on CPU for `Half`, because `log` does not support that dtype on CPU.**
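For orientation, a fast path of this kind can be sketched in plain Python. This is only a sketch under assumptions: it assumes the approach is the weighted top-k trick (draw one uniform number `u` per item, form the key `log(u) / weight`, keep the `k` largest keys), which would explain the dependence on `log`. The name `sample_without_replacement` and its `rng` parameter are hypothetical, and this is not the C++ code from the PR.

```python
import heapq
import math
import random


def sample_without_replacement(weights, k, rng=random):
    """Hypothetical helper: draw k distinct indices, weighted by `weights`."""
    positive = sum(1 for w in weights if w > 0)
    if k > positive:
        raise ValueError("k exceeds the number of positive weights")
    keys = []
    for i, w in enumerate(weights):
        if w <= 0:
            continue  # zero-weight items can never be selected
        u = 1.0 - rng.random()  # lies in (0, 1], so log(u) is finite
        # Larger weights pull the (non-positive) key toward zero,
        # which makes the item more likely to land in the top k.
        keys.append((math.log(u) / w, i))
    # One pass to build the keys plus one top-k selection; no per-sample loop.
    return [i for _, i in heapq.nlargest(k, keys)]


if __name__ == "__main__":
    print(sample_without_replacement([0.1, 0.0, 0.5, 0.4], 2))
```

The point of the sketch is that the cost is one pass over the weights plus a top-k selection, rather than one renormalized draw per requested sample.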
Benchmarks were run with the same build settings on the same system:
gcc : version 7.5.0 (Ubuntu 7.5.0-3ubuntu1~18.04)
CUDA : 10.1
GPU : 1050ti
```python
import time

import numpy as np
import torch

# CPU benchmark: (number of elements, number of repetitions)
for n, t in [(500_000, 10),
             (1_000_000, 10)]:
    for dtype in (torch.half, torch.float, torch.double):
        # Input Setup
        p = torch.from_numpy(np.random.rand(n)).to(dtype)
        want = 1000  # number of samples drawn per call
        print(f'torch.multinomial(a) a.numel() == {n} for {t} times {dtype}')
        start = time.time()
        # Iterate
        for _ in range(t):
            torch.multinomial(p, want, replacement=False)
        print('Took:', time.time() - start)

print('****' * 10)

# CUDA benchmark: (number of elements, number of repetitions)
for n, t in [(50_000, 100),
             (100_000, 100)]:
    for dtype in (torch.half, torch.float, torch.double):
        # Input Setup
        p = torch.rand(n, device='cuda', dtype=dtype)
        want = 1000  # number of samples drawn per call
        print(f'torch.multinomial(a) a.numel() == {n} for {t} times {dtype}')
        start = time.time()
        # Synchronization is left disabled, as in the original script, so the
        # timing may not include kernels that are still queued on the GPU.
        # torch.cuda.synchronize()
        # Iterate
        for _ in range(t):
            torch.multinomial(p, want, replacement=False)
        # torch.cuda.synchronize()
        print('CUDA Took:', time.time() - start)
```
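The CUDA loop above leaves its `torch.cuda.synchronize()` calls commented out. For anyone re-running the benchmark, a timing helper along the following lines may give numbers that include queued GPU work. It is a minimal sketch: `time_calls` and its `sync` parameter are hypothetical names, not part of PyTorch or of this PR.

```python
import time


def time_calls(fn, repeats, sync=None):
    """Hypothetical helper: wall-clock seconds for `repeats` calls of `fn`.

    `sync`, if given, is called before the timer starts and before it stops,
    so that asynchronously queued work is counted.
    """
    if sync is not None:
        sync()  # drain work queued before the measurement
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if sync is not None:
        sync()  # wait for work queued during the measurement
    return time.perf_counter() - start


if __name__ == "__main__":
    print(time_calls(lambda: sum(range(1000)), 100))
```

For the CUDA case one would presumably pass `sync=torch.cuda.synchronize` and `fn=lambda: torch.multinomial(p, want, replacement=False)`; the CPU case needs no `sync`.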
Before:
```
torch.multinomial(a) a.numel() == 500000 for 10 times torch.float16
Took: 80.64455389976501
torch.multinomial(a) a.numel() == 500000 for 10 times torch.float32
Took: 3.7778031826019287
torch.multinomial(a) a.numel() == 500000 for 10 times torch.float64
Took: 5.045570611953735
torch.multinomial(a) a.numel() == 1000000 for 10 times torch.float16
Took: 161.53191947937012
torch.multinomial(a) a.numel() == 1000000 for 10 times torch.float32
Took: 7.640851736068726
torch.multinomial(a) a.numel() == 1000000 for 10 times torch.float64
Took: 10.399673461914062
****************************************
torch.multinomial(a) a.numel() == 50000 for 100 times torch.float16
CUDA Took: 4.873984098434448
torch.multinomial(a) a.numel() == 50000 for 100 times torch.float32
CUDA Took: 4.713594436645508
torch.multinomial(a) a.numel() == 50000 for 100 times torch.float64
CUDA Took: 11.167185068130493
torch.multinomial(a) a.numel() == 100000 for 100 times torch.float16
CUDA Took: 7.195427417755127
torch.multinomial(a) a.numel() == 100000 for 100 times torch.float32
CUDA Took: 7.669712066650391
torch.multinomial(a) a.numel() == 100000 for 100 times torch.float64
CUDA Took: 20.20938801765442
```
After:
```
torch.multinomial(a) a.numel() == 500000 for 10 times torch.float16
Took: 80.6487455368042
torch.multinomial(a) a.numel() == 500000 for 10 times torch.float32
Took: 0.0663309097290039
torch.multinomial(a) a.numel() == 500000 for 10 times torch.float64
Took: 0.09588909149169922
torch.multinomial(a) a.numel() == 1000000 for 10 times torch.float16
Took: 161.60748076438904
torch.multinomial(a) a.numel() == 1000000 for 10 times torch.float32
Took: 0.13187885284423828
torch.multinomial(a) a.numel() == 1000000 for 10 times torch.float64
Took: 0.17609834671020508
****************************************
torch.multinomial(a) a.numel() == 50000 for 100 times torch.float16
CUDA Took: 0.007131099700927734
torch.multinomial(a) a.numel() == 50000 for 100 times torch.float32
CUDA Took: 0.022255420684814453
torch.multinomial(a) a.numel() == 50000 for 100 times torch.float64
CUDA Took: 0.0323028564453125
torch.multinomial(a) a.numel() == 100000 for 100 times torch.float16
CUDA Took: 0.04995012283325195
torch.multinomial(a) a.numel() == 100000 for 100 times torch.float32
CUDA Took: 0.04948878288269043
torch.multinomial(a) a.numel() == 100000 for 100 times torch.float64
CUDA Took: 0.05495333671569824
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39636
Differential Revision: D21925406
Pulled By: ngimel
fbshipit-source-id: f2ee5148fa7dd88e018c461ced0e2361c3a43796