pytorch/xla commit 3c83269f - Implement Fully Sharded Data Parallel (FSDP) in PyTorch XLA (#3431)

* Implement Fully Sharded Data Parallel (FSDP) in PyTorch XLA
* Move the FSDP module to `torch_xla.distributed`
* Add `mark_step_on_freeing` as a temporary workaround for #3455
* Check in `__init__` whether the module is already wrapped in FSDP; fix exception types
* Add `optimization_barrier_` (https://github.com/pytorch/xla/pull/3493) to avoid fusion of the full parameter reconstruction with the subsequent freeing
* Also apply `xm.optimization_barrier_` to the FSDP output's gradients
* Deprecate `mark_step_on_freeing` (superseded by the optimization barrier)
* Add an option to run a dummy forward pass in FSDP
* Add `_shard_size_multiple` to make sharded parameters a multiple of 128 for efficient all-gather (see https://github.com/pytorch/xla/issues/3510#issuecomment-1101739677)
* Refactor `optimization_barrier_` so it is applied separately to `_rebuild_full_params` and `_free_full_params` in the forward and backward passes
* Seal off more relevant ops with `optimization_barrier_` to avoid undesired fusion
* Remove the obsolete `mark_step_on_freeing` and `use_all_gather_via_all_reduce` configs; unpin the layout for all_reduce; add a wrapper for gradient checkpointing on modules; remove the redundant `param_names`
* Handle keyword arguments in `checkpoint_module`
* Add a gradient checkpointing option to the MNIST and ImageNet FSDP examples (a usage sketch follows this list)
* Refactor `optimization_barrier` so it is only applied in the forward or backward pass when specified
* Refactor the command line tool for consolidating sharded checkpoints
* Address reviewers' comments from GitHub
* Add more user instructions for checkpoint consolidation (a checkpoint-saving sketch follows at the end)
* Change the `flatten_parameters` default to False, since it did not bring an actual speedup in tests and it breaks optimizer parameter groups
* Refine the documentation
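The commit ships the FSDP wrapper together with a `checkpoint_module` helper for gradient checkpointing, and internally relies on `xm.optimization_barrier_` to keep the full-parameter rebuild and free steps from being fused away by XLA. Below is a minimal usage sketch, not taken from the commit itself: the class name `XlaFullyShardedDataParallel`, the `torch_xla.distributed.fsdp` import path (the message only says the module lives under `torch_xla.distributed`), and the toy model and training step are assumptions for illustration.

```python
# Minimal sketch: wrap a toy model with gradient checkpointing + XLA FSDP.
# Class name and import path below are assumptions based on the commit notes.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module

device = xm.xla_device()

model = nn.Sequential(
    nn.Linear(1024, 1024),
    nn.ReLU(),
    # Recompute this block's activations in the backward pass to save memory.
    checkpoint_module(nn.Linear(1024, 1024)),
    nn.ReLU(),
    nn.Linear(1024, 10),
).to(device)

# `flatten_parameters` defaults to False per this commit; sharded parameter
# sizes are padded to a multiple of 128 for efficient all-gather.
model = FSDP(model, flatten_parameters=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

inputs = torch.randn(8, 1024, device=device)
targets = torch.randint(0, 10, (8,), device=device)

loss = F.cross_entropy(model(inputs), targets)
loss.backward()
optimizer.step()  # gradients are already reduced onto the local shards by FSDP
xm.mark_step()    # materialize the lazily traced XLA graph
```

In a real multi-core run this would typically execute inside the per-device processes launched via `torch_xla.distributed.xla_multiprocessing.spawn`, with each device holding one shard of every parameter.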
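For checkpointing, each rank saves only its own shards plus the metadata needed to reassemble them, and the consolidation command line tool mentioned above later merges the per-rank files into a single full state dict. The sketch below continues the example above and is hedged: the `get_shard_metadata()` helper and the file-naming scheme are assumptions, since the commit message only records that the tool and its user instructions exist.

```python
# Sketch of per-rank sharded checkpoint saving (continues the FSDP example above).
# `get_shard_metadata()` and the naming scheme are assumptions, not from the commit.
import torch_xla.core.xla_model as xm

rank = xm.get_ordinal()
world_size = xm.xrt_world_size()
ckpt_path = f'/tmp/fsdp_ckpt_rank-{rank:08d}-of-{world_size:08d}.pth'

ckpt = {
    'model': model.state_dict(),                   # this rank's (sharded) parameters
    'shard_metadata': model.get_shard_metadata(),  # info needed to stitch shards back together
}
# Every rank must write its own file, so saving cannot be restricted to the master.
xm.save(ckpt, ckpt_path, master_only=False)
```

The per-rank files would then be merged offline with the consolidation command line tool; its exact module path and flags are not spelled out in the commit message, so the user instructions added for checkpoint consolidation should be consulted for the precise invocation.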