pytorch
8c96b436 - Remove opmath cast for im2col decomp (#121363)

Commit
2 years ago
Remove opmath cast for im2col decomp (#121363) It is unclear why opmath cast is needed for im2col decomp, given that the decomposition is mainly performing padding, slicing, indexing and shape manipulation. There is no need for performing these operations in a higher precision, and in doing so it requires more memory and yields less performance. Sample script to demonstrate inserted cast before this change ```python import torch from torch._decomp.decompositions import im2col def func(x): return torch.nn.functional.unfold( x, kernel_size=[3, 1], padding=[2, 0], dilation=1, stride=1 ) x = torch.rand(1, 1, 5, 5, dtype=torch.float16) eo = torch._dynamo.export( func, aten_graph=True, decomposition_table={torch.ops.aten.im2col.default: im2col} )(x) eo.graph_module.print_readable() ``` ``` class GraphModule(torch.nn.Module): def forward(self, x): arg0: "f16[1, 1, s0, s0]"; arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) arg0_1 = arg0 _to_copy: "f32[1, 1, s0, s0]" = torch.ops.aten._to_copy.default(arg0_1, dtype = torch.float32) ... constant_pad_nd: "f32[1, 1, s0 + 4, s0]" = torch.ops.aten.constant_pad_nd.default(_to_copy, [0, 0, 2, 2], 0.0); _to_copy = None ... slice_1: "f32[1, 1, s0 + 4, s0]" = torch.ops.aten.slice.Tensor(constant_pad_nd, 0, 0, 9223372036854775807); constant_pad_nd = None slice_2: "f32[1, 1, s0 + 4, s0]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807); slice_1 = None index: "f32[1, 1, 3, s0 + 2, 1, s0]" = torch.ops.aten.index.Tensor(slice_2, [None, None, unsqueeze_5, add_3]); slice_2 = unsqueeze_5 = add_3 = None permute: "f32[1, 1, 3, 1, s0 + 2, s0]" = torch.ops.aten.permute.default(index, [0, 1, 2, 4, 3, 5]); index = None ... view: "f32[1, 3, s0**2 + 2*s0]" = torch.ops.aten.view.default(permute, [1, 3, mul]); permute = mul = None _to_copy_1: "f16[1, 3, s0**2 + 2*s0]" = torch.ops.aten._to_copy.default(view, dtype = torch.float16); view = None return pytree.tree_unflatten([_to_copy_1], self._out_spec) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/121363 Approved by: https://github.com/lezcano
Author
Committer
Parents
Loading