ORTModule memory improvement (#18924)
## Dependency
https://github.com/microsoft/onnxruntime/pull/19007
## ORTModule memory efficient gradient management
Previously I tried to solve the coarse-grained gradient
accumulation/update problem in ORTModule with
https://github.com/microsoft/onnxruntime/pull/8979, but that
approach was never fully validated with DDP, or when users register
hooks on the gradient accumulation of torch parameters.
This PR addresses the problem with a similar approach to PR 8979,
i.e. triggering gradient accumulation as soon as ORT has computed the grad.
However, instead of using an AccumulateGrad op, it uses an ONNX
PythonOp that internally calls param.backward(grad), which makes
all related hooks fire correctly. See the sketch below.
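As a rough illustration of why routing the ORT-computed gradient through `param.backward(grad)` preserves hook behavior, here is a minimal standalone sketch. The function name and usage are hypothetical and this is not the actual PythonOp implementation:

```python
import torch

# Minimal sketch (hypothetical helper) of what the PythonOp does once ORT has
# finished computing a parameter's gradient: route it through
# param.backward(grad) so PyTorch's own accumulation path runs and any hooks
# registered on the parameter (e.g. by DDP or DeepSpeed) fire as usual.
def accumulate_param_grad(param: torch.nn.Parameter, grad: torch.Tensor) -> None:
    # param is a leaf tensor with requires_grad=True, so backward(gradient=...)
    # accumulates into param.grad via torch's AccumulateGrad node instead of
    # ORT writing param.grad directly.
    param.backward(grad)


if __name__ == "__main__":
    p = torch.nn.Parameter(torch.zeros(3))
    p.register_hook(lambda g: print("grad hook fired:", g))
    accumulate_param_grad(p, torch.ones(3))
    accumulate_param_grad(p, torch.ones(3))
    print(p.grad)  # tensor([2., 2., 2.]) after two accumulations
```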
## Design
See the details in
https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ
## Convergence Validation

Differences are mostly on the order of 0.000x, occasionally 0.00x, which may
come from the gradient apply happening in a different order before vs. after
this change (with DeepSpeed ZeRO stage 2).
## TODO
Consolidate this logic with the similar logic used for Stage 3.