transformers
7d6fa935 - Fix ShieldGemma2 non-reproducible outputs by adding _tied_weights_keys (#44358)

* Fix ShieldGemma2 non-reproducible outputs by adding _tied_weights_keys

  The checkpoint has `text_config.tie_word_embeddings = True`, meaning
  `lm_head.weight` should be tied to `embed_tokens.weight`. However,
  `ShieldGemma2ForImageClassification` was missing `_tied_weights_keys`, so
  `from_pretrained` treated `model.lm_head.weight` as absent and left it
  randomly initialized on every load, which caused non-reproducible outputs.

  Adding `_tied_weights_keys = {"model.lm_head.weight": "model.model.language_model.embed_tokens.weight"}`
  lets the loading machinery skip that key and populate it via `tie_weights()`
  in `from_pretrained`, which uses the dict-based tying mechanism.

  Fixes: https://huggingface.co/google/shieldgemma-2-4b-it/discussions/10

  Co-authored-by: Hardik Meisheri <hardik.meisheri@gmail.com>
  Co-authored-by: Shrey Ganatra <ganatrashrey2002@gmail.com>
  Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

* Fix ShieldGemma2 weight tying by propagating tie_word_embeddings to outer config

* Adding doc strings for tie_word_embeddings for ShieldGemma2

* Update src/transformers/models/shieldgemma2/modeling_shieldgemma2.py

  Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

---------

Co-authored-by: Shrey Ganatra <ganatrashrey2002@gmail.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>