pytorch
4f9d2f74 - Port softplus activation to Aten(CPU+CUDA) (#30504)

Port softplus activation to Aten(CPU+CUDA) (#30504)

Summary:
VitalyFedyunin, this PR ports the Softplus activation to ATen.

**Test script:**
```
import torch
import torch.nn as nn
import time

torch.manual_seed(0)

def _time():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

device = "cpu"
m = nn.Softplus()
if torch.cuda.is_available():
    device = "cuda"
    m = m.cuda()

# warm up
for n in [100, 10000]:
    input = torch.randn(128, n, requires_grad=True, device=device)
    grad_output = torch.ones(128, n, device=device)
    for i in range(1000):
        output = m(input)
        output.backward(grad_output)

for n in [100, 10000]:
    input = torch.randn(128, n, requires_grad=True, device=device)
    grad_output = torch.ones(128, n, device=device)
    fwd_t = 0
    bwd_t = 0
    for i in range(10000):
        t1 = _time()
        output = m(input)
        t2 = _time()
        output.backward(grad_output)
        t3 = _time()
        fwd_t = fwd_t + (t2 - t1)
        bwd_t = bwd_t + (t3 - t2)
    fwd_avg = fwd_t / 10000 * 1000
    bwd_avg = bwd_t / 10000 * 1000
    print("input size(128, %d) forward time is %.2f (ms); backward avg time is %.2f (ms)." % (n, fwd_avg, bwd_avg))
```

Test Device: CPU: skx-8180, GPU: Tesla P40.

Performance:

Before:
```
GPU:
input size(128, 100) forward time is 0.06 (ms); backward avg time is 0.12 (ms).
input size(128, 10000) forward time is 0.06 (ms); backward avg time is 0.18 (ms).
CPU:
input size(128, 100) forward time is 1.16 (ms); backward avg time is 0.69 (ms).
input size(128, 10000) forward time is 60.19 (ms); backward avg time is 31.86 (ms).
```

After:
```
GPU:
input size(128, 100) forward time is 0.05 (ms); backward avg time is 0.11 (ms).
input size(128, 10000) forward time is 0.06 (ms); backward avg time is 0.17 (ms).
CPU:
input size(128, 100) forward time is 0.43 (ms); backward avg time is 0.16 (ms).
input size(128, 10000) forward time is 1.65 (ms); backward avg time is 0.83 (ms).
```

`OMP_NUM_THREADS=1`:
```
Before:
input size(128, 100) forward time is 0.53 (ms); backward avg time is 0.28 (ms).
input size(128, 10000) forward time is 51.33 (ms); backward avg time is 25.48 (ms).
After:
input size(128, 100) forward time is 0.44 (ms); backward avg time is 0.16 (ms).
input size(128, 10000) forward time is 42.05 (ms); backward avg time is 13.97 (ms).
```

Fixes https://github.com/pytorch/pytorch/issues/24633, https://github.com/pytorch/pytorch/issues/24634, https://github.com/pytorch/pytorch/issues/24766, https://github.com/pytorch/pytorch/issues/24767.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/30504

Differential Revision: D19274913

Pulled By: ezyang

fbshipit-source-id: 21b29e8459dcba5a040cc68333887b45a858328e
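For reference, the math the ported kernels implement is softplus(x) = (1/beta) * log(1 + exp(beta * x)), falling back to the identity once beta * x exceeds a threshold to avoid overflow, and a backward pass that scales the incoming gradient by exp(beta * x) / (exp(beta * x) + 1). The sketch below is illustrative only: the helper names `softplus_forward` and `softplus_backward` are made up for this example, it uses plain PyTorch tensor ops rather than the ATen CPU/CUDA kernel code from this PR, and it is checked against `torch.nn.functional.softplus` with the default beta=1, threshold=20.

```
import torch

# Illustrative sketch of the Softplus math; not the ATen kernels from this PR.

def softplus_forward(x, beta=1.0, threshold=20.0):
    # softplus(x) = (1/beta) * log(1 + exp(beta * x)); switch to the identity
    # where beta * x > threshold so the exponential cannot overflow.
    scaled = beta * x
    return torch.where(scaled > threshold, x, torch.log1p(torch.exp(scaled)) / beta)

def softplus_backward(grad_output, x, beta=1.0, threshold=20.0):
    # d/dx softplus(x) = exp(beta * x) / (exp(beta * x) + 1); the derivative
    # is 1 in the thresholded (linear) region.
    scaled = beta * x
    z = torch.exp(scaled)
    return torch.where(scaled > threshold, grad_output, grad_output * z / (z + 1.0))

# Sanity check against the built-in op and autograd.
x = torch.randn(128, 100, requires_grad=True)
y = torch.nn.functional.softplus(x)
y.backward(torch.ones_like(y))
assert torch.allclose(y, softplus_forward(x.detach()), atol=1e-6)
assert torch.allclose(x.grad, softplus_backward(torch.ones_like(x), x.detach()), atol=1e-6)
```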