pytorch
dd1f2952 - [spmd] Improve activation handling, factory ops and batch dim reduction (#100853)

[spmd] Improve activation handling, factory ops and batch dim reduction (#100853)

This PR improves the activation-handling logic of data parallel to support tensor factory ops that do not depend on any input node: such ops still produce an activation, either sharded (if the output shape contains the batch size) or replicated.

It also significantly simplifies the full-reduction logic: a separate full-reduction detection pass is no longer needed. Instead, when computing the batch dim, a detected full reduction is simply marked as sharded.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100853
Approved by: https://github.com/mrshenli
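The placement rule described above can be sketched as follows. This is an illustrative simplification, not the actual PyTorch SPMD implementation; the function name and return convention are hypothetical.

```python
# Hypothetical sketch of the placement decision for a factory op's output
# under data parallelism: if any dimension of the output shape matches the
# batch size, shard the activation on that dimension; otherwise replicate it.

def factory_output_placement(output_shape, batch_size):
    """Return ("shard", dim) if a dim matches the batch size, else ("replicate", None)."""
    for dim, size in enumerate(output_shape):
        if size == batch_size:
            return ("shard", dim)
    return ("replicate", None)


# A factory op like torch.zeros(32, 128) with batch size 32 would be sharded
# on dim 0, while torch.ones(10, 10) (no batch dim) would be replicated.
print(factory_output_placement((32, 128), 32))   # ("shard", 0)
print(factory_output_placement((10, 10), 32))    # ("replicate", None)
```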