Relax restrictions of torch.autocast integration (#7543)
This PR relaxes two restrictions on torch.autocast in the DeepSpeed
engine:
1) Nesting torch.autocast
Currently, we do not expect `torch.autocast` to be used outside the
DeepSpeed engine. Here is the current behavior:
- If `torch.autocast` is enabled in the DeepSpeed config and the engine
detects it is also enabled outside, a warning is displayed.
- If it is disabled in the config, the engine raises an error.
This design prevents the following usage:
```python
with torch.autocast(...):
logits = deepspeed_model(...)
loss = criteria_fn(logits)
```
In this case, we also want to apply autocast to `criteria_fn`. With the
current behavior, we would need move `deepspeed_model(...)` outside the
`torch.autocast` context, leading to inconsistent code between DeepSpeed
and non-DeepSpeed setups. (cannot be handled with `enabled` arg of
`torch.autocast`)
Change in this PR:
`torch.autocast` outside the DeepSpeed engine is ignored, and
- If `torch_autocast` is enabled in the config, DeepSpeed will follow
that setting.
- If it is disabled, DeepSpeed falls back to its own mixed-precision
support (or FP32).
In these cases, DeepSpeed engine shows a message to explain the
behavior.
2) Model’s dtype
Previously, DeepSpeed assumed the model’s dtype must be FP32 when
`torch.autocast` was enabled. However, models with lower-precision
parameters (e.g., BF16) can also be used with autocast. For example, if
both the model and `torch.autocast` use BF16, autocast will upcast
precision-sensitive ops as needed.
Change in this PR:
Removed the assertion that restricted the model’s dtype to FP32.
This PR also adds and updates tests to cover these new behaviors.
---------
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>