Update on "refactor prepare_for_backward"
Move parts of the code in prepare_for_backward into separate functions, so that those functions can be reused later for static graph training and for delaying all-reduce.
Differential Revision: [D27439195](https://our.internmc.facebook.com/intern/diff/D27439195/)