[FSDP] Implement apply() (#72600)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72600
Implements `apply()` which applies a `callable` of signature `f(m: Module) -> None` recursively to every submodule. The main difference from `nn.module.apply` is that this version summons the full parameters before apply() so it works appropriately with FSDP.
ghstack-source-id: 149217423
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D34111109
fbshipit-source-id: 60d9d3f5c4d6c27763f5d68728dfb0bae3d9f644
(cherry picked from commit b20c65e06070f27fda0e5260f5cbbb41e3e33f46)