pytorch
93cce973 - turn on nhwc for more PW ops (#1570)

Commit
3 years ago
turn on nhwc for more PW ops (#1570) Turning permutatoin support on for a few pointwise ops: aten::softplus; aten::threshold; aten::where; aten::lerp Initial request coming from an example with poor perf for channels last input tensor. ``` def smelu(x: torch.Tensor, beta: float): zero=torch.empty((), dtype=x.dtype, device=x.device).fill_(0.) # to avoid synchronizing h2d transfer return torch.where(x >= beta, x, torch.where(x <= - beta, zero, (x+beta)*(x+beta)/(4*beta))) memory_format=torch.channels_last x = torch.randn(10240000, device="cuda").view(128, 80, 10, 100).to(memory_format=memory_format).detach().requires_grad_(False) scripted = torch.jit.script(smelu) smelu(x, 2.) gO = torch.rand_like(x) with torch.jit.fuser("fuser2"): for _ in range(10): x.grad = None out = scripted(x, 2.) ```
Author
Parents
Loading