fix(models): Fix dtype mismatch in SwitchTransformers and TimmWrapperModel #45074
fix: Cast inputs to match weight dtype
cd1a4c94
new: Add test
e05bd6ad
harshaljanjani
marked this pull request as ready for review 5 days ago
Merge branch 'main' into fix/switch-transformers-timm-wrapper-bf16-dtype
32817246
change: Upcast to float32 instead of downcasting
faa66b3b
harshaljanjani
deleted the fix/switch-transformers-timm-wrapper-bf16-dtype branch 1 hour ago
Assignees
No one assigned
Login to write a write a comment.
Login via GitHub