[JIT] Separate GPU implementation of frozen_conv_add_relu_fusion.cpp (#68149)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68149
JIT optimization passes are part of the CPU-only build (i.e. necessary GPU flags are not passed in). This separates the implementation of frozen_conv_add_relu_fusion so that the GPU-enabled implementation is registered at runtime (if it is available)
ghstack-source-id: 143676384
Test Plan:
In the following script, conv_add_relu fusion is not observed without this change, but is observed when this change is added.
```
from typing import List, Optional
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.nn.Parameter(torch.rand((3, 3, 7, 7), device="cuda"))
self.add_tensor = torch.nn.Parameter(torch.rand((3, 3, 7, 7), device="cuda"))
def forward(
self,
inp: torch.Tensor,
bias: Optional[torch.Tensor],
stride: List[int],
padding: List[int],
dilation: List[int],
groups: int,
):
# weight = torch.zeros((3, 3, 7, 7), device="cuda")
inp = inp.to("cuda")
conv_result = torch.conv2d(
inp, self.weight, bias, stride, padding, dilation, groups
)
add_result = conv_result.add_(self.add_tensor)
return add_result.relu_()
torch.jit.export
def make_prediction(self, inp: torch.Tensor):
bias = None
groups = 1
stride = (1, 1)
padding = (0, 0)
dilation = (1, 1)
return self.forward(inp, bias, stride, padding, dilation, groups)
if __name__ == "__main__":
# generate some sample input
groups = 1
channels_in = 3
channels_out = 3
kernel_size = (7, 7)
stride = (1, 1)
padding = (0, 0)
dilation = (1, 1)
inp = torch.rand((64, 3, 432, 432))
weight = torch.rand(
(channels_out, channels_in, kernel_size[0], kernel_size[1]), device="cuda"
)
bias = None
model = Model()
model.eval()
script = torch.jit.script(model)
script = torch.jit.freeze(script)
script = torch.jit.optimize_for_inference(script)
print("~~~~ FORWARD ~~~~")
print(script.graph)
print("with preserved_attrs")
print(torch.sum(script.forward(inp, bias, stride, padding, dilation, groups)))
```
Reviewed By: cpuhrsch
Differential Revision: D32329330
fbshipit-source-id: c0f10da4b9540c588819efe3ec540baa0fae4b35