add flatten parameter module (#66578)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66578
Flatten parameters into a single contiguous tensor as a performance optimization, and handle the cases where the gradient-ready order differs across ranks or some parameters are unused on a subset of ranks. When an FSDP instance has no parameters to shard (usually the root), the flatten wrapper module's flat_param is None.
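
A minimal sketch of the flattening idea (illustrative only; the class name FlattenParamsSketch is hypothetical, and the actual FSDP flatten wrapper additionally replaces the original parameters with views into the flat buffer, handles state dicts, etc.):

```python
import torch
import torch.nn as nn

class FlattenParamsSketch(nn.Module):
    """Hypothetical sketch of parameter flattening, not the real wrapper."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        params = list(module.parameters())
        if not params:
            # Matches the behavior described above: with no parameters
            # to shard (often the root FSDP instance), flat_param is None.
            self.register_parameter("flat_param", None)
        else:
            # One contiguous 1-D buffer lets collectives (all-gather,
            # reduce-scatter) run once per module instead of per parameter.
            self.flat_param = nn.Parameter(
                torch.cat([p.detach().reshape(-1) for p in params])
            )

wrapped = FlattenParamsSketch(nn.Linear(4, 2))
assert wrapped.flat_param.numel() == 4 * 2 + 2  # weight + bias elements

empty = FlattenParamsSketch(nn.Sequential())
assert empty.flat_param is None
```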
ghstack-source-id: 140696745
Test Plan: unit test
Reviewed By: mrshenli
Differential Revision: D31625194
fbshipit-source-id: c40e84f9154f5703e5bacb02c37c59d6c4e055c7