pytorch
379fb476 - [SPMD] Support foreach optimizers with functionalization (#97853)

Commit
2 years ago
[SPMD] Support foreach optimizers with functionalization (#97853) My first attempt was to apply the same solution as how proxy_tensor.py handles other inplace ops. However, foreach is different in the way that it's schema is `native_functions.yaml` does not return anything, whereas ops like `addcmul_` and `addcdiv_` do return Tensors (Thanks bdhirsh for teaching me this!). As a result, the proxy output during tracing does not wrap anything, and hence we cannot correctly connect it with subsequent operators. Modifying `native_functions.yaml` is not a preferred solution. After discussing with bdhirsh, the temporary solution is to do foreach functionalization as a graph pass for now. Later, when https://github.com/pytorch/pytorch/issues/97852 is addressed, we will switch to default functionalization. Edit: the latest version follows @bdhirsh 's suggestion on using `make_fx` `decomposition_table` instead of implementing manual fx.Graph tranforms to functionalize `_foreach_add_`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/97853 Approved by: https://github.com/fegin, https://github.com/wanchaol
Author
Committer
Parents
Loading