Enable FSDP ``use_orig_params=True`` mixed precision training when some ranks have no (non-zero sized) parameter shards (#99175)
Fixes #99174
### The issue
Now that ``use_orig_params=True`` allows non-uniform ``requires_grad`` (:tada: :rocket: thanks @awgu!!!) with [#98221](https://github.com/pytorch/pytorch/pull/98221), there will be circumstances wherein some ranks have no (non-zero sized) local shards of the original parameters (and hence no associated gradients).
### Use Cases
For a simple Transformer case, imagine a user wraps all encoder layers in separate FSDP instances but allows the classifier head to be wrapped in the same FSDP instance as the relatively large embedding layers. While this is a sub-optimal wrapping strategy for most use cases, I believe it is expected to be supported (full-precision training already works in that context).
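For illustration, a minimal sketch of that wrapping (the module names and sizes are hypothetical, and it assumes a default process group has already been initialized):
```python
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyTransformer(nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(50_000, 512)  # relatively large embedding
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model=512, nhead=8) for _ in range(4)
        )
        self.head = nn.Linear(512, 2)  # small classifier head


# Each ``TransformerEncoderLayer`` gets its own FSDP instance; ``embed`` and
# ``head`` fall into the root instance, so the small ``head`` parameters can
# end up with zero-sized shards on some ranks of a large enough world.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)
model = FSDP(ToyTransformer(), auto_wrap_policy=auto_wrap_policy, use_orig_params=True)
```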
I originally encountered this issue while extending a package I maintain, leveraging the relaxed ``requires_grad`` constraint to simplify multi-phase scheduled fine-tuning FSDP configuration, so a [concrete example is there](https://finetuning-scheduler.readthedocs.io/en/latest/advanced/fsdp_scheduled_fine_tuning.html#basic-scheduled-fine-tuning-with-fsdp).
### Reproduction and Remediation
Currently, ``ShardedGradScaler`` does not accommodate these situations, failing to initialize ``optimizer_state["found_inf_per_device"]`` when ``unscale_`` is called.
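To make the failure mode concrete, here is a hedged, self-contained sketch (not the upstream implementation) of how the per-device inf checks are only recorded while iterating over gradients, along with the shape of the remediation:
```python
import torch


def unscale_sketch(grads, inv_scale, device):
    # Hedged sketch, not the actual upstream code: ``found_inf_per_device``
    # is only populated while iterating over gradients, so it stays empty on
    # a rank that holds no (non-zero sized) parameter shards.
    found_inf_per_device = {}
    for grad in grads:
        grad.mul_(inv_scale)
        found_inf = (~torch.isfinite(grad)).any().float().reshape(1)
        found_inf_per_device[grad.device] = found_inf
    # Remediation shape: seed an inf-check entry even when ``grads`` is empty
    # so that ``step`` finds a recorded check on every rank.
    if not found_inf_per_device:
        found_inf_per_device[device] = torch.zeros(1, device=device)
    return found_inf_per_device
```
Without that seeding, a rank with no gradients records nothing and the ``assert len(optimizer_state["found_inf_per_device"]) > 0`` in ``GradScaler.step`` is what fires in the traceback below.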
In this PR, I extend the existing ``ShardedGradScaler`` tests with a ``use_orig_params=True`` dimension added to the parameterization and test scenarios wherein one rank possesses no (non-zero sized) parameter shards.
The relevant issue can be reproduced with the tests I'm adding in this PR. The current (pre-PR) execution of these tests fails in ``use_orig_params=True`` mode with this error:
```python
./test_fsdp_sharded_grad_scaler.py::TestShardedGradScalerParityWithDDP::test_fsdp_ddp_parity_with_grad_scaler_offload_false_none_mixed_precision_use_orig_params Failed with Error: Process 0 exited with error code 10 and exception:
Traceback (most recent call last):
File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_distributed.py", line 657, in run_test
getattr(self, test_name)()
File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_distributed.py", line 543, in wrapper
fn()
File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_utils.py", line 259, in instantiated_test
test(self, **param_kwargs)
File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper
return func(*args, **kwargs)
File "/home/speediedan/repos/pytorch/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py", line 187, in test_fsdp_ddp_parity_with_grad_scaler
self._test_fsdp_parity(
File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_fsdp.py", line 1152, in _test_fsdp_parity
fsdp_loss = self._train_for_several_steps(
File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_fsdp.py", line 1016, in _train_for_several_steps
sharded_grad_scaler.step(optim)
File "/home/speediedan/repos/pytorch/torch/distributed/fsdp/sharded_grad_scaler.py", line 291, in step
return super().step(optimizer, *args, **kwargs)
File "/home/speediedan/repos/pytorch/torch/cuda/amp/grad_scaler.py", line 368, in step
assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
```
A few implementation notes/considerations and questions:
1. Rather than just initialize ``per_device_found_inf``, one could disable the grad scaler altogether for the relevant ranks, altering ``unscale_`` to reduce with a subgroup or some rank-mask construct so that the ``all_reduce`` calls in ``distributed/fsdp/sharded_grad_scaler.py:unscale_()`` do not hang. Given that users may subsequently add parameter groups to an optimizer that would require re-enabling the scaler, and given the complexity of maintaining a separate mask construct or process subgroup, I thought this implementation was cleaner.
2. I extended ``_train_for_several_steps`` and ``_test_fsdp_parity`` in ``/torch/testing/_internal/common_fsdp.py`` with the ability to configure ``sharded_grad_scaler_kwargs`` for future testing flexibility (a sketch of that plumbing follows this list).
3. Should the user be warned that no parameter shards were associated with a given rank? My initial thought is that this should be considered an implementation detail, part of supporting ``use_orig_params`` with heterogeneous ``requires_grad``, and therefore should be transparently handled by PyTorch. Should a DEBUG level message be added? If so, likely further upstream rather than at the scaler step level.
4. Rather than extend the existing ``ShardedGradScaler`` tests with a ``use_orig_params=True`` dimension added to the parameterization, let me know if you prefer that I instead narrow the scope of the new testing to a single additional test, e.g.:
```python
# Before: from typing import Optional
from typing import Optional, List
# ...
# Before: use_orig_params = ["enable_use_orig_params", None]
use_orig_params: List[Optional[str]] = [None]
# ...
configs = list(itertools.product(cpu_offload_config, sharding_strategy_config, mixed_precision, use_orig_params))
# Append a single ``use_orig_params=True`` configuration rather than crossing it
# with every other parameterization dimension:
configs.append((CPUOffload(offload_params=False), None, "enable_mixed_precision", "enable_use_orig_params"))
```
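And, regarding note 2 above, a hedged sketch of the ``sharded_grad_scaler_kwargs`` plumbing (the real ``common_fsdp.py`` helpers take many more arguments; ``get_loss`` here is a stand-in for the test's forward/loss logic):
```python
from typing import Any, Dict, Optional

from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def _train_for_several_steps_sketch(
    model,
    optim,
    num_steps: int,
    sharded_grad_scaler_kwargs: Optional[Dict[str, Any]] = None,
):
    # Forward any user-provided kwargs (e.g. ``init_scale``) to the scaler.
    scaler = ShardedGradScaler(**(sharded_grad_scaler_kwargs or {}))
    for _ in range(num_steps):
        optim.zero_grad()
        loss = model.get_loss()  # stand-in for the test's forward pass
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
```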
Thanks as always to the PyTorch distributed team for your astonishingly impressive and valuable contributions to the open-source ML engineering community!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99175
Approved by: https://github.com/awgu