pytorch
4d29b402 - torch.compile DTensor E2E (#105236)

Commit
1 year ago
torch.compile DTensor E2E (#105236) This PR updates DTensor to support torch.compile Cool stuff: there are some new tests in `test_dtensor.py` that show both the forward and backward graphs that we can send to inductor, when running a matmul with DTensor's. In particular, for this user code: ``` def fn(x, y): dt = DTensor.from_local(x.reshape(2, 4), mesh, [Shard(0)], run_check=False) dt2 = DTensor.from_local(y.reshape(4, 2), mesh, [Shard(1)], run_check=False) dt_out = torch.matmul(dt, dt2) dt_out_redistribute = dt_out.redistribute(mesh, [Replicate()]) return dt_out.to_local() ``` We generate the following fw and backward graphs. Forward graph: ``` def forward(self, primals_1, primals_2): view = torch.ops.aten.view.default(primals_1, [2, 4]); primals_1 = None _to_copy = torch.ops.aten._to_copy.default(view, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0)); view = None detach = torch.ops.aten.detach.default(_to_copy); _to_copy = None detach_1 = torch.ops.aten.detach.default(detach); detach = None view_1 = torch.ops.aten.view.default(primals_2, [4, 2]); primals_2 = None _to_copy_1 = torch.ops.aten._to_copy.default(view_1, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0)); view_1 = None detach_2 = torch.ops.aten.detach.default(_to_copy_1); _to_copy_1 = None detach_3 = torch.ops.aten.detach.default(detach_2); detach_2 = None detach_4 = torch.ops.aten.detach.default(detach_1) all_gather_into_tensor = torch.ops.c10d_functional.all_gather_into_tensor.default(detach_3, 'ptd:0', [0, 1], 2) wait_tensor = torch.ops.c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None split = torch.ops.aten.split.Tensor(wait_tensor, 4); wait_tensor = None getitem = split[0] getitem_1 = split[1]; split = None cat = torch.ops.aten.cat.default([getitem, getitem_1], 1); getitem = getitem_1 = None detach_5 = torch.ops.aten.detach.default(cat); cat = None mm = torch.ops.aten.mm.default(detach_4, detach_5); detach_4 = detach_5 = None detach_6 = torch.ops.aten.detach.default(mm); mm = None detach_9 = torch.ops.aten.detach.default(detach_6); detach_6 = None detach_10 = torch.ops.aten.detach.default(detach_9); detach_9 = None t = torch.ops.aten.t.default(detach_1); detach_1 = None detach_13 = torch.ops.aten.detach.default(t); t = None t_1 = torch.ops.aten.t.default(detach_3); detach_3 = None detach_15 = torch.ops.aten.detach.default(t_1); t_1 = None clone = torch.ops.aten.clone.default(detach_15, memory_format = torch.contiguous_format); detach_15 = None return [detach_10, detach_13, clone] ``` Backward graph: ``` def forward(self, detach_13, clone, tangents_1): detach_11 = torch.ops.aten.detach.default(tangents_1); tangents_1 = None detach_12 = torch.ops.aten.detach.default(detach_11); detach_11 = None mm_1 = torch.ops.aten.mm.default(detach_13, detach_12); detach_13 = None detach_14 = torch.ops.aten.detach.default(mm_1); mm_1 = None detach_16 = torch.ops.aten.detach.default(detach_12); detach_12 = None all_gather_into_tensor_2 = torch.ops.c10d_functional.all_gather_into_tensor.default(clone, 'ptd:0', [0, 1], 2); clone = None wait_tensor_2 = torch.ops.c10d_functional.wait_tensor.default(all_gather_into_tensor_2); detach_17 = torch.ops.aten.detach.default(wait_tensor_2); wait_tensor_2 = None mm_2 = torch.ops.aten.mm.default(detach_16, detach_17); detach_16 = detach_17 = None detach_18 = torch.ops.aten.detach.default(mm_2); mm_2 = None split_1 = torch.ops.aten.split.Tensor(detach_14, 2, 1); detach_14 = None getitem_2 = split_1[0] getitem_3 = split_1[1]; split_1 = None cat_1 = torch.ops.aten.cat.default([getitem_2, getitem_3]); getitem_2 = getitem_3 = None reduce_scatter_tensor = torch.ops.c10d_functional.reduce_scatter_tensor.default(cat_1, 'SUM', 'ptd:0', [0, 1], 2); cat_1 = None wait_tensor_3 = torch.ops.c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None detach_19 = torch.ops.aten.detach.default(wait_tensor_3); wait_tensor_3 = None detach_20 = torch.ops.aten.detach.default(detach_19); detach_19 = None detach_21 = torch.ops.aten.detach.default(detach_20); detach_20 = None detach_22 = torch.ops.aten.detach.default(detach_21); detach_21 = None _to_copy_2 = torch.ops.aten._to_copy.default(detach_22, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); detach_22 = None view_2 = torch.ops.aten.view.default(_to_copy_2, [8]); _to_copy_2 = None detach_23 = torch.ops.aten.detach.default(detach_18); detach_18 = None detach_24 = torch.ops.aten.detach.default(detach_23); detach_23 = None _to_copy_3 = torch.ops.aten._to_copy.default(detach_24, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); detach_24 = None view_3 = torch.ops.aten.view.default(_to_copy_3, [8]); _to_copy_3 = None return [view_3, view_2] ``` Some of the stuff in this graph looks kinda of silly though (e.g. an unnecessary split() + cat(), and all the extra detach() calls). Stuff that's broken: - functionalization is pretty horribly broken. In particular, the original strategy I used in this stack was to have functionalization run **above** subclass desugaring. But that doesn't play well with with the way we want to compile DTensor. DTensor has a few API's like `.redistribute()`, `.to_local()`, and the `DTensor()` constructor, that we want to put directly into the graph so that we can compile them (e.g. redistribute() will desugar into collective ops). Doing this requires functionalization to run **underneath** the subclass though. I hacked around this for now, by forcing these functions to run functionalization first if they need to. - the backward test that I have is... wrong. The backward graph that we trace out looks kind of reasonable, but it gives incorrect gradients on one of the two inputs. This needs further debugging (presumably we should be able to stare at the graph and identify which part of it is wrong?). Pull Request resolved: https://github.com/pytorch/pytorch/pull/105236 Approved by: https://github.com/wanchaol
Author
Committer
Parents
Loading