Fix RoPE precision loss for Llama / Gemma + Gemma logits.float() (#29285)
* Update modeling_llama.py
Llama: force the RoPE cos/sin computation to float32, since bfloat16 loses precision at long context lengths (see the sketch after the commit list)
* Update modeling_llama.py
* Update modeling_gemma.py
Gemma: fix RoPE the same way and cast the output logits via logits.float()
* @torch.no_grad()
* @torch.no_grad()
* Cos, Sin to float32
* cos, sin to float32
* Update src/transformers/models/gemma/modeling_gemma.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Resolve PR conflicts
* Fix RoPE for llama
* Revert "Fix RoPE for llama"
This reverts commit b860a22dab9bb01cd15cb9a3220abeaefad3e458.
* Fix RoPE for llama
* RoPE device
* Autocast device type
* RoPE
* RoPE isinstance
---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
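What the commits above amount to, as a minimal sketch rather than the exact transformers code (the module name `Float32RotaryEmbedding` and its signature are illustrative): compute the rotary cos/sin tables in float32 under `@torch.no_grad()` with autocast disabled for the current device type, then cast back to the activation dtype only at the end. The companion Gemma change additionally casts the final logits to float32 (`logits = logits.float()`) before the loss.

```python
# Hedged sketch of the float32 RoPE pattern described in the commit list above.
import torch
from torch import nn


class Float32RotaryEmbedding(nn.Module):
    """Illustrative rotary embedding that keeps the cos/sin math in float32."""

    def __init__(self, dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()  # the rotary tables never need gradients
    def forward(self, x, position_ids):
        # Shapes: (batch, dim/2, 1) and (batch, 1, seq_len), both promoted to float32.
        inv_freq = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        pos = position_ids[:, None, :].float()
        # Disable autocast so the matmul / cos / sin stay in float32 even under a
        # bf16 autocast context; "mps" has no autocast support, so fall back to "cpu".
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq @ pos).transpose(1, 2)    # (batch, seq_len, dim/2)
            emb = torch.cat((freqs, freqs), dim=-1)     # (batch, seq_len, dim)
            cos, sin = emb.cos(), emb.sin()
        # Cast back to the activation dtype only after the float32 math is done.
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
```

Keeping the angle computation in float32 matters at long positions because `position_ids * inv_freq` grows large and bfloat16 has only about 8 bits of mantissa, so nearby positions can round to the same angle; usage is the same as any RoPE module, e.g. `cos, sin = rope(hidden_states, position_ids)`.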