transformers
9264fc91 - Inconsistency in PreTrainedModel.resize_token_embeddings When ZeRO3 Is Enabled (#25394)

Inconsistency in PreTrainedModel.resize_token_embeddings When ZeRO3 Is Enabled (#25394)

* Inconsistency in PreTrainedModel.resize_token_embeddings

This PR addresses https://github.com/huggingface/transformers/issues/25241.
In the previous implementation, when ZeRO stage 3 was enabled,
resize_token_embeddings would create independent PyTorch weights on each
device. Here we ensure that the new embeddings are created with DeepSpeed
init and are properly partitioned across devices.

* formatting with black

* adding the removed comments back in

---------

Co-authored-by: Sina Moeini <smoeini@amazon.com>
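For context, a minimal sketch of the approach the commit describes, not the verbatim patch: the resized embedding is constructed inside DeepSpeed's ZeRO-3 `Init` context so its weight is partitioned across ranks instead of materialized in full on every device, and the copy of the old rows happens under `GatheredParameters`. The helper name `_resize_embeddings_zero3` is hypothetical, and the `deepspeed_config` import path assumes a recent transformers version.

```python
import torch.nn as nn
import deepspeed
from transformers.integrations.deepspeed import deepspeed_config


def _resize_embeddings_zero3(old_embeddings: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    """Hypothetical helper: resize an embedding matrix under ZeRO stage 3."""
    # Under ZeRO-3 the weight is partitioned, so gather it to read the full shape.
    with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
        old_num_tokens, embedding_dim = old_embeddings.weight.size()

    # Create the new embedding inside deepspeed.zero.Init so the fresh weight is
    # itself partitioned across ranks rather than duplicated per device (this is
    # the core of the fix described above).
    with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
        new_embeddings = nn.Embedding(new_num_tokens, embedding_dim)
    new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)

    # Gather both weights with rank 0 as the modifier, copy the overlapping rows,
    # and let DeepSpeed re-partition the modified parameter on context exit.
    n = min(old_num_tokens, new_num_tokens)
    params = [old_embeddings.weight, new_embeddings.weight]
    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
        new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
    return new_embeddings
```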