DeepSpeed
6e5d58d2 - OptimizedLinear updates (#5791)

Commit
1 year ago
OptimizedLinear updates (#5791) This is a refresh of of `OptimizedLinear` with the following features to improve performance and usability: * More efficient sharing of base weights using `all_gather_into_tensor` * Flattened sharded weights * Selectively offload frozen weights to cpu * `deepspeed.linear.Init` that allows injecting OptimizedLinear during model construction (similar to zero.Init) * Support for load state dict directly in OptimizedLinear, this allows loading HF model weights correctly into sharded params * Various bug fixes for the LoRA implementation introduced previously * Several new unit tests Builds on-top of @RezaYazdaniAminabadi's previous FP8 updates (#5764) to support dense model fp8 quantization. Example usage of this to fine-tune llama-3.1-405B on a single node: https://github.com/Snowflake-Labs/snowflake-arctic/tree/main/training/llama3.1 --------- Co-authored-by: Reza Yazdani <reza.yazdani@snowflake.com> Co-authored-by: Reza Yazdani <152926435+sfc-gh-reyazda@users.noreply.github.com>
Author
Parents
Loading