transformers
20164cc2 - RoPE loses precision for Llama / Gemma + Gemma logits.float() (#29285)

Commit
1 year ago
RoPE loses precision for Llama / Gemma + Gemma logits.float() (#29285) * Update modeling_llama.py Llama - Force float32 since bfloat16 loses precision on long contexts * Update modeling_llama.py * Update modeling_gemma.py Fix RoPE and logits.float() * @torch.no_grad() * @torch.no_grad() * Cos, Sin to float32 * cos, sin to float32 * Update src/transformers/models/gemma/modeling_gemma.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Resolve PR conflicts * Fix RoPE for llama * Revert "Fix RoPE for llama" This reverts commit b860a22dab9bb01cd15cb9a3220abeaefad3e458. * Fix RoPE for llama * RoPE device * Autocast device type * RoPE * RoPE isinstance --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Author
Committer
Parents
Loading