Substantially reduce memory usage in _update_causal_mask for large batches by using .expand instead of .repeat [needs tests+sanity check] (#29413)
* try to fix gemma mem use
* fix: handle attention mask dim==2 case
* remove logits=logits.float()
* clean up + add llama
* apply formatting
* readability edit: swap order of items being multiplied
* revert change unrelated to PR
* revert black autoformat
* switch to one .to
* Accept style edits
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>