pytorch
d690a596 - Fast path binary ops in fake tensor (#94047)

Fast path binary ops in fake tensor (#94047)

Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup.

Before:

```
cuda eval  hrnet_w18  PASS
TIMING: entire_frame_compile:53.97591 backend_compile:33.60832
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

After:

```
cuda eval  hrnet_w18  PASS
TIMING: entire_frame_compile:40.18931 backend_compile:25.28828
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit#

This is not the "most" optimized version of the code; compare with the Horace/Voz roofline experiment:

```diff
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e3bf545f3b8..395942c6ffe 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode):
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}
 
+        with no_dispatch():
+            if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}:
+                return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda')
+
         if func == torch.ops.prim.device.default:
             assert len(args) == 1 and isinstance(args[0], FakeTensor)
             if args[0].fake_mode.in_kernel_invocation:
```

I am still leaving about 5s of trace time improvement on the table (3s of which is attributable to not yet handling relu).

The implementation here is based off of https://github.com/pytorch/pytorch/pull/93118/ but I modeled the short circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences:

* Traditional fast setup in TensorIterator only short circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fastpath'ed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but I only allow the fast path if (1) at least one input tensor has a shape that is exactly the output size, and (2) all the tensors are contiguous (or all the tensors are channels last). A sketch of this check is included at the end of this message.
* I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right.

Some evidence that this heuristic is correct is here: https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758. I exhaustively test all dim=3 tensors with sizes [1, 2] and show that we get the same significant strides between PrimTorch and the new algorithm (a sketch of this style of exhaustive check also appears at the end of this message). There ARE differences between this algorithm and PrimTorch, but this algorithm agrees with TensorIterator where PrimTorch is wrong (sample case: size=(1, 1, 2), stride=(1, 1, 1), stride=(1, 1, 1)).

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94047
Approved by: https://github.com/eellison
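
As a rough illustration of the short-circuit heuristic above (not the actual fake_tensor.py code): `_try_fast_path_shape` is a hypothetical helper name, it works against the public tensor API instead of FakeTensor internals, and it ignores wrapped numbers, which the real implementation has to handle separately.

```python
import torch

def _try_fast_path_shape(tensor_args):
    # Hypothetical helper: return the broadcasted output shape if the fast
    # path heuristic applies, otherwise None.  Wrapped numbers (Python
    # scalars turned into 0-dim tensors) are not handled in this sketch.
    shapes = [t.shape for t in tensor_args]
    out_shape = torch.broadcast_shapes(*shapes)

    # (1) At least one input must already have exactly the output shape.
    if not any(s == out_shape for s in shapes):
        return None

    # (2) All inputs are contiguous, or all inputs are channels-last.
    if all(t.is_contiguous() for t in tensor_args):
        return out_shape
    if all(t.is_contiguous(memory_format=torch.channels_last) for t in tensor_args):
        return out_shape
    return None

a = torch.empty(4, 3, 8, 8)
b = torch.empty(3, 1, 1)
print(_try_fast_path_shape([a, b]))                  # torch.Size([4, 3, 8, 8])
print(_try_fast_path_shape([a.transpose(2, 3), b]))  # None: not all contiguous, not all channels-last
```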
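
And a sketch of the style of exhaustive stride check the gist performs. Note this version compares eager (TensorIterator) output strides against the meta device purely as a stand-in; the actual gist compares PrimTorch against the new fast-path algorithm.

```python
import itertools
import torch

# Enumerate every pair of dim=3 shapes whose sizes are drawn from {1, 2} and
# check that the "significant" output strides (dims of size > 1) agree.
for lhs in itertools.product([1, 2], repeat=3):
    for rhs in itertools.product([1, 2], repeat=3):
        eager = torch.add(torch.empty(lhs), torch.empty(rhs))
        meta = torch.add(torch.empty(lhs, device="meta"),
                         torch.empty(rhs, device="meta"))
        assert eager.shape == meta.shape
        for size, s_eager, s_meta in zip(eager.shape, eager.stride(), meta.stride()):
            # Strides of size-1 dims are not significant and may legally differ.
            if size > 1:
                assert s_eager == s_meta, (lhs, rhs, eager.stride(), meta.stride())
print("all significant strides agree")
```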