pytorch
b8fa41be - Support map autograd and pytree in/out (#100494)

Commit
1 year ago
Support map autograd and pytree in/out (#100494) This PR adds autograd and pytree support for map operator. Implementation-wise: 1. We temporarily make two HigherOrderOperators, "map" and "map_impl": - "map" is user-facing. Currently, it unwraps the pytrees in inputs and create a flat_fn for it. Dynamo currently cannot deal with pytree.tree_flatten and pytree.tree_unflatten, we therefore make it a HigherOrderOperator to trigger dynamo logic of handling HigherOrderOperators. - "map_impl" is the actual operator that works with the rest of torch subsystems such as functionalization, make_fx. It accepts flattend arguments, and a num_mapped_args integer denoting how many of the flattend arguments need to mapped i.e. their first dimension will be unstacked. 2. We create the forward and backward graph in autograd key and call torch.autograd.Function. Currently, the backward graph is recomputation-based and we need to partition the joint graph in the future to be more efficient. Example traced graphs for map operators: ### Case 1: simple f and autograd ```python def f(x, y): return x + y def g(xs, y): out = control_flow.map(f, xs, y) return torch.autograd.grad(out, (xs, y), torch.ones_like(out)) gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True)) # gm.print_readable() produces following: class g(torch.nn.Module): def forward(self, xs_1: f32[3, s1, s2], y_1: f32[s2]): # No stacktrace found for following nodes body_graph_0 = self.body_graph_0 map_impl = torch.ops.map_impl(body_graph_0, 1, xs_1, y_1); body_graph_0 = None getitem: f32[3, s1, s2] = map_impl[0]; map_impl = None ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False) is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like); getitem = None body_graph_1 = self.body_graph_1 map_impl_1 = torch.ops.map_impl(body_graph_1, 2, xs_1, ones_like, y_1); body_graph_1 = xs_1 = ones_like = None getitem_1 = map_impl_1[0] getitem_2: f32[3, s1, s2] = map_impl_1[1] getitem_3: f32[3, s2] = map_impl_1[2]; map_impl_1 = None sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_3, [0], True); getitem_3 = None sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0); y_1 = None view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = sym_size = None return (getitem_2, view) class <lambda>(torch.nn.Module): def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s2]): # No stacktrace found for following nodes add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = None return [add] class <lambda>(torch.nn.Module): def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]): # No stacktrace found for following nodes add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg3_1); arg1_1 = None is_same_size = torch.ops.aten.is_same_size.default(add, arg2_1); add = None sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg2_1, [0], True) sym_size: Sym(s2) = torch.ops.aten.sym_size(arg3_1, 0); arg3_1 = None view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = sym_size = None return [None, arg2_1, view] ``` ### Case 2: list input/output f and autograd ```python def f(x, y): return [x[0].cos() + y.sin(), x[1].sin() * y.cos()] def g(xs, y): out = control_flow.map(f, xs, y) flat_out, _ = pytree.tree_flatten(out) flat_inp, _ = pytree.tree_flatten((xs, y)) requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad] return torch.autograd.grad(flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out]) gm = make_fx(g, tracing_mode="symbolic")( [torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)], torch.ones(5, requires_grad=True)) # gm.print_readable() produces following: class g(torch.nn.Module): def forward(self, xs, y): xs_1: f32[3, s1, s2], xs_2: f32[3, s1, s2], y_1: f32[s2], = fx_pytree.tree_flatten_spec([xs, y], self._in_spec) # No stacktrace found for following nodes body_graph_0 = self.body_graph_0 map_impl = torch.ops.map_impl(body_graph_0, 2, xs_1, xs_2, y_1); body_graph_0 = None getitem: f32[3, s1, s2] = map_impl[0] getitem_1: f32[3, s1, s2] = map_impl[1]; map_impl = None ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False) ones_like_1: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem_1, pin_memory = False) is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like); getitem = None is_same_size_1 = torch.ops.aten.is_same_size.default(getitem_1, ones_like_1); getitem_1 = None body_graph_1 = self.body_graph_1 map_impl_1 = torch.ops.map_impl(body_graph_1, 4, xs_1, xs_2, ones_like, ones_like_1, y_1); body_graph_1 = xs_1 = xs_2 = ones_like = ones_like_1 = None getitem_2 = map_impl_1[0] getitem_3 = map_impl_1[1] getitem_4: f32[3, s1, s2] = map_impl_1[2] getitem_5: f32[3, s2] = map_impl_1[3]; map_impl_1 = None sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_5, [0], True); getitem_5 = None sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0); y_1 = None view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = sym_size = None return pytree.tree_unflatten([getitem_4, view], self._out_spec) class <lambda>(torch.nn.Module): def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]): # No stacktrace found for following nodes cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1); arg1_1 = None sin: f32[s2] = torch.ops.aten.sin.default(arg3_1) add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1); arg2_1 = None cos_1: f32[s2] = torch.ops.aten.cos.default(arg3_1); arg3_1 = None mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1); sin_1 = cos_1 = None return [add, mul] class <lambda>(torch.nn.Module): def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s1, s2], arg4_1: f32[s1, s2], arg5_1: f32[s2]): # No stacktrace found for following nodes cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1); arg1_1 = None sin: f32[s2] = torch.ops.aten.sin.default(arg5_1) add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1) cos_1: f32[s2] = torch.ops.aten.cos.default(arg5_1) mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1) is_same_size = torch.ops.aten.is_same_size.default(add, arg3_1); add = None is_same_size_1 = torch.ops.aten.is_same_size.default(mul, arg4_1); mul = None mul_1: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, sin_1); sin_1 = None mul_2: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, cos_1); arg4_1 = cos_1 = None sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(mul_1, [0], True); mul_1 = None sym_size: Sym(s2) = torch.ops.aten.sym_size(arg5_1, 0) view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = None # sin_2: f32[s2] = torch.ops.aten.sin.default(arg5_1) neg: f32[s2] = torch.ops.aten.neg.default(sin_2); sin_2 = None mul_3: f32[s2] = torch.ops.aten.mul.Tensor(view, neg); view = neg = None cos_2: f32[s1, s2] = torch.ops.aten.cos.default(arg2_1); arg2_1 = None mul_4: f32[s1, s2] = torch.ops.aten.mul.Tensor(mul_2, cos_2); mul_2 = cos_2 = None sum_2: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg3_1, [0], True); arg3_1 = None view_1: f32[s2] = torch.ops.aten.view.default(sum_2, [sym_size]); sum_2 = sym_size = None cos_3: f32[s2] = torch.ops.aten.cos.default(arg5_1); arg5_1 = None mul_5: f32[s2] = torch.ops.aten.mul.Tensor(view_1, cos_3); view_1 = cos_3 = None add_1: f32[s2] = torch.ops.aten.add.Tensor(mul_3, mul_5); mul_3 = mul_5 = None return [None, None, mul_4, add_1] ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/100494 Approved by: https://github.com/zou3519
Author
Committer
Parents
Loading