Support map autograd and pytree in/out (#100494)
This PR adds autograd and pytree support for the map operator.
Implementation-wise:
1. We temporarily introduce two HigherOrderOperators, "map" and "map_impl":
- "map" is user-facing. It unwraps the pytrees in its inputs and creates a flat_fn over the flattened arguments. Because Dynamo currently cannot handle pytree.tree_flatten and pytree.tree_unflatten, we make "map" a HigherOrderOperator so that Dynamo's logic for handling HigherOrderOperators is triggered.
- "map_impl" is the actual operator that interacts with the rest of the torch subsystems, such as functionalization and make_fx. It accepts flattened arguments plus a num_mapped_args integer denoting how many of the flattened arguments need to be mapped, i.e. have their first dimension unstacked. A minimal sketch of this flattening convention is shown right after this list.
2. In the Autograd dispatch key, we create the forward and backward graphs and call torch.autograd.Function. Currently, the backward graph is recomputation-based; in the future we will need to partition the joint graph to make this more efficient. A sketch of this recomputation pattern follows the Case 1 example below.
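For intuition, here is a minimal eager-mode sketch of the flattening convention from point 1. The names `simple_map` and `simple_map_impl` are hypothetical stand-ins for the "map" and "map_impl" HigherOrderOperators; this only illustrates the calling convention and is not the actual implementation.
```python
import torch
import torch.utils._pytree as pytree


def simple_map_impl(f, num_mapped_args, *flat_args):
    # Stand-in for "map_impl": the first `num_mapped_args` tensors are
    # unstacked along dim 0 and mapped over; the remaining tensors are passed
    # unchanged to every iteration. Each iteration returns a flat list of
    # tensors, which we stack back per output.
    mapped, rest = flat_args[:num_mapped_args], flat_args[num_mapped_args:]
    per_iter = [f(*slices, *rest) for slices in zip(*(t.unbind(0) for t in mapped))]
    return [torch.stack(outs) for outs in zip(*per_iter)]


def simple_map(f, xs, *args):
    # Stand-in for the user-facing "map": flatten the pytree `xs`, build a
    # flat-calling body, dispatch to the flat operator, then unflatten the
    # flat outputs back into the body's output pytree structure.
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    out_spec = None

    def flat_f(*flat_args):
        nonlocal out_spec
        unflat_xs = pytree.tree_unflatten(list(flat_args[:len(flat_xs)]), xs_spec)
        flat_out, out_spec = pytree.tree_flatten(f(unflat_xs, *flat_args[len(flat_xs):]))
        return flat_out

    flat_outs = simple_map_impl(flat_f, len(flat_xs), *flat_xs, *args)
    return pytree.tree_unflatten(flat_outs, out_spec)


# Example: a dict-of-tensors pytree mapped over dim 0, plus one unmapped arg.
xs = {"a": torch.ones(3, 4), "b": torch.ones(3, 4)}
y = torch.ones(4)
out = simple_map(lambda x, y: [x["a"].cos() + y.sin(), x["b"].sin() * y.cos()], xs, y)
# out is a list of two tensors, each of shape [3, 4]
```
The key point is that the flat operator only ever sees a flat list of tensors plus num_mapped_args, which is the form that functionalization and make_fx work with in the traced graphs below.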
Example traced graphs for the map operator:
### Case 1: simple f and autograd
```python
def f(x, y):
    return x + y

def g(xs, y):
    out = control_flow.map(f, xs, y)
    return torch.autograd.grad(out, (xs, y), torch.ones_like(out))

gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True))

# gm.print_readable() produces the following:
class g(torch.nn.Module):
    def forward(self, xs_1: f32[3, s1, s2], y_1: f32[s2]):
        # No stacktrace found for following nodes
        body_graph_0 = self.body_graph_0
        map_impl = torch.ops.map_impl(body_graph_0, 1, xs_1, y_1); body_graph_0 = None
        getitem: f32[3, s1, s2] = map_impl[0]; map_impl = None
        ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False)
        is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like); getitem = None
        body_graph_1 = self.body_graph_1
        map_impl_1 = torch.ops.map_impl(body_graph_1, 2, xs_1, ones_like, y_1); body_graph_1 = xs_1 = ones_like = None
        getitem_1 = map_impl_1[0]
        getitem_2: f32[3, s1, s2] = map_impl_1[1]
        getitem_3: f32[3, s2] = map_impl_1[2]; map_impl_1 = None
        sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_3, [0], True); getitem_3 = None
        sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0); y_1 = None
        view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = sym_size = None
        return (getitem_2, view)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s2]):
            # No stacktrace found for following nodes
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = None
            return [add]

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]):
            # No stacktrace found for following nodes
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(arg1_1, arg3_1); arg1_1 = None
            is_same_size = torch.ops.aten.is_same_size.default(add, arg2_1); add = None
            sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg2_1, [0], True)
            sym_size: Sym(s2) = torch.ops.aten.sym_size(arg3_1, 0); arg3_1 = None
            view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = sym_size = None
            return [None, arg2_1, view]
```
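Note how `body_graph_1` above first recomputes the forward `add` before producing the gradients; this is the recomputation-based backward from point 2. Below is a minimal eager sketch of that pattern, reusing the hypothetical `simple_map_impl` from the sketch above. `MapAutogradOp` is also a hypothetical name; the sketch assumes every input participates in every output and that gradients are provided for all outputs, and it is not the actual implementation.
```python
import torch  # simple_map_impl is taken from the sketch earlier in this description


class MapAutogradOp(torch.autograd.Function):
    # Recomputation-based backward: rather than saving intermediates, the
    # backward re-runs the flat body on each slice under grad mode and calls
    # torch.autograd.grad on the recomputed outputs.
    @staticmethod
    def forward(ctx, flat_f, num_mapped_args, *flat_args):
        ctx.flat_f = flat_f
        ctx.num_mapped_args = num_mapped_args
        ctx.save_for_backward(*flat_args)
        with torch.no_grad():
            return tuple(simple_map_impl(flat_f, num_mapped_args, *flat_args))

    @staticmethod
    def backward(ctx, *flat_grads):
        n = ctx.num_mapped_args
        mapped, rest = ctx.saved_tensors[:n], ctx.saved_tensors[n:]
        mapped_grads = []
        rest_grads = [torch.zeros_like(t) for t in rest]
        for i in range(mapped[0].size(0)):
            # Grad mode must be re-enabled inside backward to build a graph
            # for the recomputed forward of iteration i.
            with torch.enable_grad():
                ins = ([t[i].detach().requires_grad_(True) for t in mapped]
                       + [t.detach().requires_grad_(True) for t in rest])
                outs = ctx.flat_f(*ins)
                gins = torch.autograd.grad(outs, ins, [g[i] for g in flat_grads])
            mapped_grads.append(gins[:n])
            rest_grads = [acc + g for acc, g in zip(rest_grads, gins[n:])]
        # Gradients for mapped args are stacked back along dim 0; gradients for
        # unmapped args are summed over iterations. The two leading Nones match
        # the non-tensor inputs (flat_f, num_mapped_args).
        return (None, None,
                *[torch.stack(list(gs)) for gs in zip(*mapped_grads)],
                *rest_grads)


# Usage mirroring Case 1 above: one mapped arg, one unmapped arg.
xs = torch.ones(3, 4, requires_grad=True)
y = torch.ones(4, requires_grad=True)

def body(x, y):
    return [x + y]

out, = MapAutogradOp.apply(body, 1, xs, y)  # shape [3, 4]
out.sum().backward()                        # xs.grad: ones; y.grad: 3 * ones
```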
### Case 2: list input/output f and autograd
```python
def f(x, y):
    return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]

def g(xs, y):
    out = control_flow.map(f, xs, y)
    flat_out, _ = pytree.tree_flatten(out)
    flat_inp, _ = pytree.tree_flatten((xs, y))
    requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
    return torch.autograd.grad(flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out])

gm = make_fx(g, tracing_mode="symbolic")(
    [torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)],
    torch.ones(5, requires_grad=True))

# gm.print_readable() produces the following:
class g(torch.nn.Module):
    def forward(self, xs, y):
        xs_1: f32[3, s1, s2], xs_2: f32[3, s1, s2], y_1: f32[s2], = fx_pytree.tree_flatten_spec([xs, y], self._in_spec)
        # No stacktrace found for following nodes
        body_graph_0 = self.body_graph_0
        map_impl = torch.ops.map_impl(body_graph_0, 2, xs_1, xs_2, y_1); body_graph_0 = None
        getitem: f32[3, s1, s2] = map_impl[0]
        getitem_1: f32[3, s1, s2] = map_impl[1]; map_impl = None
        ones_like: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem, pin_memory = False)
        ones_like_1: f32[3, s1, s2] = torch.ops.aten.ones_like.default(getitem_1, pin_memory = False)
        is_same_size = torch.ops.aten.is_same_size.default(getitem, ones_like); getitem = None
        is_same_size_1 = torch.ops.aten.is_same_size.default(getitem_1, ones_like_1); getitem_1 = None
        body_graph_1 = self.body_graph_1
        map_impl_1 = torch.ops.map_impl(body_graph_1, 4, xs_1, xs_2, ones_like, ones_like_1, y_1); body_graph_1 = xs_1 = xs_2 = ones_like = ones_like_1 = None
        getitem_2 = map_impl_1[0]
        getitem_3 = map_impl_1[1]
        getitem_4: f32[3, s1, s2] = map_impl_1[2]
        getitem_5: f32[3, s2] = map_impl_1[3]; map_impl_1 = None
        sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(getitem_5, [0], True); getitem_5 = None
        sym_size: Sym(s2) = torch.ops.aten.sym_size(y_1, 0); y_1 = None
        view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = sym_size = None
        return pytree.tree_unflatten([getitem_4, view], self._out_spec)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s2]):
            # No stacktrace found for following nodes
            cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1); arg1_1 = None
            sin: f32[s2] = torch.ops.aten.sin.default(arg3_1)
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
            sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1); arg2_1 = None
            cos_1: f32[s2] = torch.ops.aten.cos.default(arg3_1); arg3_1 = None
            mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1); sin_1 = cos_1 = None
            return [add, mul]

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1, arg1_1: f32[s1, s2], arg2_1: f32[s1, s2], arg3_1: f32[s1, s2], arg4_1: f32[s1, s2], arg5_1: f32[s2]):
            # No stacktrace found for following nodes
            cos: f32[s1, s2] = torch.ops.aten.cos.default(arg1_1); arg1_1 = None
            sin: f32[s2] = torch.ops.aten.sin.default(arg5_1)
            add: f32[s1, s2] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
            sin_1: f32[s1, s2] = torch.ops.aten.sin.default(arg2_1)
            cos_1: f32[s2] = torch.ops.aten.cos.default(arg5_1)
            mul: f32[s1, s2] = torch.ops.aten.mul.Tensor(sin_1, cos_1)
            is_same_size = torch.ops.aten.is_same_size.default(add, arg3_1); add = None
            is_same_size_1 = torch.ops.aten.is_same_size.default(mul, arg4_1); mul = None
            mul_1: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, sin_1); sin_1 = None
            mul_2: f32[s1, s2] = torch.ops.aten.mul.Tensor(arg4_1, cos_1); arg4_1 = cos_1 = None
            sum_1: f32[1, s2] = torch.ops.aten.sum.dim_IntList(mul_1, [0], True); mul_1 = None
            sym_size: Sym(s2) = torch.ops.aten.sym_size(arg5_1, 0)
            view: f32[s2] = torch.ops.aten.view.default(sum_1, [sym_size]); sum_1 = None
            #
            sin_2: f32[s2] = torch.ops.aten.sin.default(arg5_1)
            neg: f32[s2] = torch.ops.aten.neg.default(sin_2); sin_2 = None
            mul_3: f32[s2] = torch.ops.aten.mul.Tensor(view, neg); view = neg = None
            cos_2: f32[s1, s2] = torch.ops.aten.cos.default(arg2_1); arg2_1 = None
            mul_4: f32[s1, s2] = torch.ops.aten.mul.Tensor(mul_2, cos_2); mul_2 = cos_2 = None
            sum_2: f32[1, s2] = torch.ops.aten.sum.dim_IntList(arg3_1, [0], True); arg3_1 = None
            view_1: f32[s2] = torch.ops.aten.view.default(sum_2, [sym_size]); sum_2 = sym_size = None
            cos_3: f32[s2] = torch.ops.aten.cos.default(arg5_1); arg5_1 = None
            mul_5: f32[s2] = torch.ops.aten.mul.Tensor(view_1, cos_3); view_1 = cos_3 = None
            add_1: f32[s2] = torch.ops.aten.add.Tensor(mul_3, mul_5); mul_3 = mul_5 = None
            return [None, None, mul_4, add_1]
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100494
Approved by: https://github.com/zou3519