[FSDP2] Added mixed precision (#118223)
This PR adds mixed precision configured via `MixedPrecisionPolicy`.
- By default (`cast_forward_inputs=True`), each FSDP module will cast forward floating-point input tensors to `param_dtype` if specified. If the user wants to own the cast, then the user can disable it by passing `False`.
- Symmetrically, by default (`output_dtype=None`) each FSDP module will not cast the forward output. If the user wants to customize the output dtype, then the user can pass a `torch.dtype`.
- `param_dtype` configures the unsharded parameters' dtype for forward/backward computation and hence the all-gather dtype.
- `reduce_dtype` configures the gradient reduction dtype. If `reduce_dtype=None` and `param_dtype is not None`, then `reduce_dtype` inherits from `param_dtype` for simplicity.
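A minimal usage sketch (assuming the FSDP2 `fully_shard` entry point and its `mp_policy` argument from `torch.distributed._composable.fsdp` as of this stack; the model and dtypes are illustrative):

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

# Assumes torch.distributed is already initialized with a CUDA backend.
# bf16 parameters for forward/backward (and hence bf16 all-gathers);
# reduce_dtype is None, so gradient reduction inherits bf16 from param_dtype.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=None,          # inherits from param_dtype
    output_dtype=None,          # do not cast the forward output
    cast_forward_inputs=True,   # cast floating-point forward inputs to param_dtype
)

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
for layer in model:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)
```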
We test against a manually implemented reference implementation instead of comparing against the existing FSDP, since that gives a more direct comparison with what we want to test.
---
**Overhead benchmarks to inform design**
The dilemma is as follows:
- The common path for FSDP is bf16 parameter mixed precision, where we cast sharded parameters from fp32 to bf16 before all-gathering them.
- The baseline implementation is to `torch._foreach_copy_` the sharded parameters to the flat `all_gather_input`, which gets passed to `dist.all_gather_into_tensor`.
- The baseline incurs 1 extra fp32 read and 1 extra bf16 write per parameter because `_foreach_copy_` takes the slow path, calling `copy_` in a loop, and `copy_` calls `dst.copy_(src.to(bf16))` where `dst` is bf16 and `src` is fp32.
- These `copy_` calls stay in C++ and do not require calling `at::as_strided`.
- The issue with this baseline implementation is that it requires knowing up front that all parameters in the group will be cast from fp32 to bf16 in order to issue this `_foreach_copy_` from fp32 sources into a bf16 destination.
- We want per-parameter FSDP to support mixed-dtype all-gathers, where different parameters contribute all-gather inputs of different dtypes that are viewed as uint8 to form a combined flat all-gather input; the viewing-as-uint8 step is only needed in the mixed-dtype case (see the sketch after this list).
- However, the extra cast and view calls incur more CPU overhead, so we want to investigate this in more detail.
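As referenced above, a hedged sketch of the two copy-in variants being compared (buffer names and shapes are illustrative, not the actual FSDP2 internals):

```python
import torch

# Illustrative sharded fp32 parameters and flat all-gather input buffers.
sharded_params = [torch.randn(n, device="cuda") for n in (1024, 2048, 4096)]
numels = [p.numel() for p in sharded_params]

# Baseline: the single target dtype is known up front, so allocate a bf16 flat
# buffer and let _foreach_copy_ do the fp32 -> bf16 cast (slow path: copy_ in a loop).
all_gather_input = torch.empty(sum(numels), dtype=torch.bfloat16, device="cuda")
torch._foreach_copy_(list(all_gather_input.split(numels)), sharded_params)

# Mixed-dtype-friendly variant: cast each parameter to its all-gather dtype,
# view the results as uint8, and copy into a flat uint8 buffer.
casted = [p.to(torch.bfloat16) for p in sharded_params]
as_u8 = [t.view(torch.uint8) for t in casted]
u8_numels = [t.numel() for t in as_u8]
all_gather_input_u8 = torch.empty(sum(u8_numels), dtype=torch.uint8, device="cuda")
torch._foreach_copy_(list(all_gather_input_u8.split(u8_numels)), as_u8)
```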
We consider 150 `nn.Parameter`s with shapes taken from an internal model (the shapes only affect the copy bandwidth, not the CPU overhead). We focus on world size 128 first and consider two experiments: (1) run the copy-in with no head start, allowing CPU boundedness to affect GPU time, and (2) run the copy-in with a CPU head start, removing CPU overhead from affecting GPU time.
No head start:
- Baseline `torch._foreach_copy_`: 0.525 ms CPU; 0.528 ms GPU
- `.to(bf16)` before `torch._foreach_copy_`: 0.828 ms CPU; 0.836 ms GPU
- `.to(bf16).view(uint8)` before `torch._foreach_copy_`: 0.933 ms CPU; 0.937 ms GPU
Head start (removing CPU boundedness from GPU times):
- Baseline `torch._foreach_copy_`: 0.393 ms GPU
- `.to(bf16)` before `torch._foreach_copy_`: 0.403 ms GPU
- `.to(bf16).view(uint8)` before `torch._foreach_copy_`: 0.403 ms GPU
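For reference, one way such a CPU head start can be arranged in a standalone microbenchmark (a sketch, not necessarily the exact harness used here; `torch.cuda._sleep` is a private helper):

```python
import torch

# Same illustrative setup as the copy-in sketch above.
sharded_params = [torch.randn(n, device="cuda") for n in (1024, 2048, 4096)]
numels = [p.numel() for p in sharded_params]
flat = torch.empty(sum(numels), dtype=torch.bfloat16, device="cuda")
splits = list(flat.split(numels))

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# Head start: a long dummy kernel keeps the GPU busy while the CPU enqueues the
# copy-in work, so CPU overhead no longer gates the measured GPU time.
torch.cuda._sleep(int(1e8))  # spins the GPU for ~1e8 cycles

start.record()
torch._foreach_copy_(splits, [p.to(torch.bfloat16) for p in sharded_params])
end.record()
torch.cuda.synchronize()
print(f"copy-in GPU time: {start.elapsed_time(end):.3f} ms")
```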
Some other interesting notes:
- Constructing a set of all all-gather input dtypes: ~0.015 ms -- this would be the overhead cost of checking whether we need to view as uint8 (i.e. whether we have mixed dtype); alternatively, we could always view as uint8 (but that loses the mixed precision policy info from the profiler trace)
- Changing from `[t.to(bf16).view(uint8) for t in ts]` to two list comprehensions like `[t.to(bf16) for t in ts]; [t.view(uint8) for t in ts]` actually reduces CPU overhead 🤔 (by ~0.04 ms)
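For concreteness, a hedged sketch of these two notes (the helper name and structure are illustrative, not FSDP2's actual internals):

```python
import torch

def maybe_view_as_uint8(all_gather_inputs):
    """Illustrative helper: only view inputs as uint8 when their dtypes are mixed."""
    # Constructing this set costs ~0.015 ms for 150 tensors.
    dtypes = {t.dtype for t in all_gather_inputs}
    if len(dtypes) == 1:
        # Single dtype: skip the uint8 view and keep dtype info visible in traces.
        return list(all_gather_inputs)
    return [t.view(torch.uint8) for t in all_gather_inputs]

# The cast and the view written as two separate list comprehensions, which was
# measured to be slightly cheaper on the CPU than one fused comprehension.
ts = [torch.randn(64), torch.randn(128)]
casted = [t.to(torch.bfloat16) for t in ts]
viewed = [t.view(torch.uint8) for t in casted]
```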
We see that the main difference is just CPU overhead; the GPU times are almost the same. (Actually, sweeping over world sizes 8, 16, 32, and 64, we do see a difference in GPU time inversely proportional to world size, as expected since smaller world sizes copy more data. However, even at world size 8, the difference is only 0.407 ms vs. 0.445 ms GPU time.) Note, though, that the CPU overhead differences are exacerbated when the PyTorch profiler is enabled, and by how much seems to depend on the CPU's capability.
Seeing these numbers, I am inclined to simply incur the CPU overhead, especially given that if we want to support the mixed-dtype case for fp8 all-gather, we will need to incur it anyway. If the CPU overhead becomes a problem on a real workload, then we will need to explore options at that point, one possibility being `torch.compile`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118223
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: #119550, #118136