refactor prepare_for_backward (#54977)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54977
put part of codes in prepare_for_backward into functions, so that those functions can be used in static graph training and delay all reduce later on.
ghstack-source-id: 126366714
Test Plan: unit tests
Reviewed By: rohan-varma
Differential Revision: D27439195
fbshipit-source-id: 8899eda621260232d774cb145f9c6d683c47e188