Inconsistency in PreTrainedModel.resize_token_embeddings When ZeRO3 Is Enabled (#25394)
* Inconsistency in PreTrainedModel.resize_token_embeddings
This PR addresses https://github.com/huggingface/transformers/issues/25241.
In previous implementation when ZeRO stage 3 was enbaled, resize_token_embeddings would create independent PyTorch weights on each device. Here we ensure that new embeddings are created with DeepSpeed init, and are properly partitioned accros devices.
* formatting with black
* adding the removed comments back in
---------
Co-authored-by: Sina Moeini <smoeini@amazon.com>