Reinplace auto_functionalized (#120829)
Fixes https://github.com/pytorch/pytorch/issues/120441
We follow how triton_kernel_wrapper_functional gets re-inplaced:
- If we see auto_functionalized, then first we compute what inputs we
actually need to clone ("tensors_to_clone") and fixup the graph. This happens in
`reinplace_and_refine_tensors_to_clone`, which I have refactored out
of the triton_kernel_wrapper_functional reinplacing code.
- Later on, after the reinplacing pass, we have a decomposition pass for
auto_functionalized. In that decomposition pass, we make use of the
"tensor_to_clone" info and only clone those inputs in the
decomposition.
- We shepherd "tensor_to_clone" from the first step to the second step
by setting the .meta field on the auto_functionalized node.
Test Plan:
- existing tests
- tested locally by reading the output of TORCH_LOGS="post_grad_graphs"
- added assertExpectedInline tests for the post_grad_graphs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120829
Approved by: https://github.com/oulgen