Implement Fully Sharded Data Parallel (FSDP) in PyTorch XLA (#3431)
* Implement Fully Sharded Data Parallel (FSDP) in PyTorch XLA
* move the FSDP module to `torch_xla.distributed`
* add `mark_step_on_freeing` as a temporary workaround to #3455
* check in `__init__` whether the module is already an FSDP instance; fix exception types
* add `optimization_barrier_` (https://github.com/pytorch/xla/pull/3493) to avoid fusion of the full parameter reconstruction with its subsequent freeing (see the barrier sketch after this list)
* also apply `xm.optimization_barrier_` to the FSDP output's gradients
* deprecate `mark_step_on_freeing` (no longer needed now that we have the optimization barrier)
* add an option to run a dummy forward pass in FSDP
* add `_shard_size_multiple` to pad sharded parameters to a multiple of 128 for efficient all-gather (see https://github.com/pytorch/xla/issues/3510#issuecomment-1101739677; the padding arithmetic is sketched after this list)
* refactor `optimization_barrier_` so that it is applied separately to `_rebuild_full_params` and `_free_full_params` in the forward and backward passes
* seal off more relevant ops with `optimization_barrier_` to avoid undesired fusion
* remove the obsolete `mark_step_on_freeing` and `use_all_gather_via_all_reduce` configs; unpin the layout for `all_reduce`; add a wrapper for gradient checkpointing on modules; remove the redundant `param_names`
* handle keyword arguments in `checkpoint_module` (see the gradient-checkpointing sketch after this list)
* add a gradient checkpointing option to the MNIST and ImageNet FSDP examples
* refactor `optimization_barrier` and apply it in the forward or backward pass only when specified
* refactor the command-line tool that consolidates sharded checkpoints
* address reviewers' comments from GitHub
* add more user instructions for checkpoint consolidation
* change the `flatten_parameters` default to False, since flattening did not bring an actual speedup in tests and breaks optimizer parameter groups (the usage sketch after this list passes `flatten_parameters=False` explicitly)
* refine the documentation
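
For reference, a minimal usage sketch of the wrapper on a single XLA device per process. The tiny `nn.Linear` model and hyperparameters are placeholders, and the sketch assumes the `XlaFullyShardedDataParallel` class exported from `torch_xla.distributed.fsdp`; it also passes `flatten_parameters=False` explicitly, matching the new default:

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP


def train_step():
    device = xm.xla_device()
    # Wrapping shards the parameters across devices; full parameters are
    # rebuilt (all-gathered) on the fly in forward/backward and freed afterwards.
    model = FSDP(nn.Linear(1024, 1024).to(device), flatten_parameters=False)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    inputs = torch.randn(8, 1024, device=device)
    loss = model(inputs).sum()
    loss.backward()
    # FSDP already reduces gradients across devices in its backward pass,
    # so step the optimizer directly instead of using xm.optimizer_step.
    optimizer.step()
    xm.mark_step()
```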
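
`xm.optimization_barrier_` (https://github.com/pytorch/xla/pull/3493) is an in-place call over a list of tensors; the sketch below only illustrates the general API rather than the exact call sites inside the FSDP wrapper:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
full_param = torch.randn(1024, 1024, device=device)

# The in-place barrier tells XLA not to fuse the ops producing these tensors
# with the ops that later consume or overwrite them; FSDP uses it to keep
# full-parameter reconstruction from being fused with the subsequent freeing
# of the full parameters (and likewise for the output's gradients).
xm.optimization_barrier_([full_param])
```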
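
The 128-element alignment can be illustrated with a small arithmetic sketch; `padded_shard_size` is a hypothetical helper written for illustration only and is not the actual `_shard_size_multiple` code path:

```python
def padded_shard_size(numel: int, world_size: int, multiple: int = 128) -> int:
    """Per-device shard size when a parameter with `numel` elements is
    sharded across `world_size` devices and each shard is padded up to a
    multiple of `multiple` elements for a more efficient all-gather."""
    per_shard = -(-numel // world_size)            # ceil division
    return -(-per_shard // multiple) * multiple    # round up to next multiple


# Example: a 1000-element parameter on 8 devices needs ceil(1000 / 8) = 125
# elements per shard, which is padded up to 128.
assert padded_shard_size(1000, 8) == 128
```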
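
A sketch of the gradient-checkpointing wrapper, assuming `checkpoint_module` is importable from `torch_xla.distributed.fsdp` alongside the FSDP class (the wrapped block is an arbitrary example):

```python
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import (
    XlaFullyShardedDataParallel as FSDP,
    checkpoint_module,
)

device = xm.xla_device()
block = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
# checkpoint_module recomputes the wrapped module's forward pass during the
# backward pass instead of keeping its activations, trading compute for
# memory; keyword arguments to the wrapped forward are supported as well.
model = FSDP(checkpoint_module(block).to(device))
```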