Make DimConstraints create actionable message (#100103)
This pr makes summary of dimension constraints actionable. Before the pr, it will print:
```
torch.fx.experimental.symbolic_shapes: [WARNING] Summary of dimension constraints:
The following dimensions have been specialized and CANNOT be dynamic.
NOTE: Specializations will happen by default with `assume_static_by_default=True`.
L['c'].size()[1] == 3
L['a'].size()[2] == 3
L['a'].size()[1] == 3
L['b'].size()[2] == 2
L['b'].size()[1] == 2
L['c'].size()[2] == 3
The following dimensions CAN be dynamic.
You can use the following code to specify the constraints they must satisfy:
'''
constraints=[
dynamic_dim(L['c'], 0) == dynamic_dim(L['a'], 0),
2 <= dynamic_dim(L['b'], 0),
2 <= dynamic_dim(L['a'], 0),
]
'''
```
Users need to initialize the L environment manually and copy the constraints over. After the pr, we have:
```
[2023-04-26 05:43:12,849] torch._dynamo.eval_frame: [WARNING] Summary of dimension constraints:
The following dimensions have been specialized and CANNOT be dynamic.
NOTE: Specializations will happen by default with `assume_static_by_default=True`.
'''
def specializations(a, b, c):
return (a.size()[2] == 3 and
c.size()[1] == 3 and
a.size()[1] == 3 and
c.size()[2] == 3 and
b.size()[2] == 2 and
b.size()[1] == 2)
'''
The following dimensions CAN be dynamic.
You can use the following code to specify the constraints they must satisfy:
'''
def specify_constraints(a, b, c):
return [
2 <= dynamic_dim(b, 0),
dynamic_dim(c, 0) == dynamic_dim(a, 0),
2 <= dynamic_dim(a, 0),
]
'''
```
, where dynamic_constraints has the same input signature as users code. This allow users to copy-paste and run the code to generate the constraints before exporting as shown below:
```
def specify_constraints(a, b, c):
return [
2 <= dynamic_dim(b, 0),
dynamic_dim(c, 0) == dynamic_dim(a, 0),
2 <= dynamic_dim(a, 0),
]
torch._dynamo.export(my_dyn_fn, x, y, z, constraints=specify_constriants(x, y, z))
```
Implementation-wise, this pr also
1. changes shape_env.produce_guards to produce_guards_and_constraints,
2. adds contraints_export_fn hooks,
The purpose is to surface the DimConstraints to dynamo.export, where we could reliably get the original function's signature.
The alternative to the above is to get the function signature before creating SHAPE_ENV guard (https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/output_graph.py#L227) and pass it to DimConstraints, but I couldn't recover the signature before creating SHAPE_ENV because the frame's f_globals/locals don't contain the original function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100103
Approved by: https://github.com/guangy10, https://github.com/tugsbayasgalan