Fix inefficient recursive update in ShardedTensor.state_dict hook (#68806)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/68805
The bug is described in the linked issue. This PR is an attempt to make the functions `_recurse_update_dict` and `_recurse_update_module` more efficient in how they iterate over the submodules. The previous implementation was suboptimal, as it recursively called the update method on the submodules returned by `module.named_modules()`, while `module.named_modules()` already returned all submodules including nested ones.
cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68806
Reviewed By: pritamdamania87
Differential Revision: D33053940
Pulled By: wanchaol
fbshipit-source-id: 3e72822f65a641939fec40daef29c806af725df6