OptimizedLinear updates (#5791)
This is a refresh of `OptimizedLinear` with the following features to
improve performance and usability:
* More efficient sharing of base weights using `all_gather_into_tensor`
* Flattened sharded weights
* Selective offloading of frozen weights to CPU
* `deepspeed.linear.Init` that allows injecting `OptimizedLinear` during
model construction (similar to `zero.Init`)
* Support for loading a state dict directly into `OptimizedLinear`, which
allows loading HF model weights correctly into sharded params
* Various bug fixes for the LoRA implementation introduced previously
* Several new unit tests
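To illustrate the first two bullets, here is a minimal, dependency-free sketch of the flattened-shard idea: each rank keeps only a flat 1-D slice of the base weight, and the full matrix is rebuilt on demand by concatenating all shards and reshaping (which is what `all_gather_into_tensor` does efficiently on real tensors). All function names below are illustrative, not DeepSpeed's API.

```python
# Conceptual sketch only (plain Python lists, no torch/deepspeed).
# Real code operates on torch tensors across process ranks.

def shard_flat_weight(weight_rows, world_size):
    """Flatten a 2-D weight (list of rows) and split it into equal flat shards."""
    flat = [x for row in weight_rows for x in row]
    # In practice the flat weight is padded to a multiple of world_size;
    # kept exact here for clarity.
    assert len(flat) % world_size == 0
    n = len(flat) // world_size
    return [flat[i * n:(i + 1) * n] for i in range(world_size)]

def all_gather_flat(shards, rows, cols):
    """Rebuild the full 2-D weight from the flat shards (the 'all-gather')."""
    flat = [x for shard in shards for x in shard]
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]

weight = [[1, 2, 3, 4],
          [5, 6, 7, 8]]                             # 2x4 base weight
shards = shard_flat_weight(weight, world_size=4)    # each rank holds 2 values
restored = all_gather_flat(shards, rows=2, cols=4)
print(shards[0])              # → [1, 2]
print(restored == weight)     # → True
```

Because every rank's shard is a contiguous flat slice, a single collective over the flat buffer reconstructs the weight with no per-parameter bookkeeping.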
Builds on top of @RezaYazdaniAminabadi's previous FP8 updates (#5764) to
support dense-model FP8 quantization.
Example usage of this feature to fine-tune llama-3.1-405B on a single node:
https://github.com/Snowflake-Labs/snowflake-arctic/tree/main/training/llama3.1
---------
Co-authored-by: Reza Yazdani <reza.yazdani@snowflake.com>
Co-authored-by: Reza Yazdani <152926435+sfc-gh-reyazda@users.noreply.github.com>