Implemented staged tests for the train loop
Summary:
Some models want to implement a three-stage process for the train loop so that run contexts can be applied to individual stages.
The three stages we introduce are named `forward()`, `backward()`, and `optimizer()`. Some run contexts, such as `amp` and `torchdynamo`, may want to apply only to the `forward()` stage, not to `backward()` or `optimizer()`.
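A minimal sketch of the staged idea, in plain Python with no PyTorch dependency. Only the stage names `forward()`, `backward()`, and `optimizer()` come from this change; `StagedModel`, `recording_context`, `forward_contexts`, and `events` are hypothetical names used for illustration, and the recording context is a stand-in for a real run context such as `amp`. This is not the code added in this diff.

```python
from contextlib import ExitStack, contextmanager

events = []  # records the order of stage and context events


@contextmanager
def recording_context(name):
    """Hypothetical stand-in for a run context such as amp or torchdynamo."""
    events.append(f"enter {name}")
    try:
        yield
    finally:
        events.append(f"exit {name}")


class StagedModel:
    """Hypothetical model exposing the three stages named in the summary."""

    def __init__(self, forward_contexts=()):
        # Factories for contexts that should wrap forward() only.
        self.forward_contexts = list(forward_contexts)

    def forward(self):
        events.append("forward")

    def backward(self):
        events.append("backward")

    def optimizer(self):
        events.append("optimizer")

    def train(self):
        # Enter the run contexts for the forward stage only.
        with ExitStack() as stack:
            for make_context in self.forward_contexts:
                stack.enter_context(make_context())
            self.forward()
        # backward() and optimizer() run outside those contexts.
        self.backward()
        self.optimizer()


model = StagedModel(forward_contexts=[lambda: recording_context("amp")])
model.train()
print(events)
# ['enter amp', 'forward', 'exit amp', 'backward', 'optimizer']
```

Splitting `train()` into stages is what lets the context exit before `backward()` runs; with a single monolithic step, the context would have to wrap all three stages or none.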
Reviewed By: jspark1105
Differential Revision: D40155380
fbshipit-source-id: e515eb9f85cad54348159286ba895e8d60a3dd8e