transformers
f8eda599 - [FlaxBert] Fix non-broadcastable attention mask for batched forward-passes (#8791)

[FlaxBert] Fix non-broadcastable attention mask for batched forward-passes (#8791)

* [FlaxBert] Fix non-broadcastable attention mask for batched forward-passes
* [FlaxRoberta] Fix non-broadcastable attention mask
* Use jax.numpy instead of ordinary numpy (otherwise not jit-able)
* Partially revert "Use jax.numpy ..."
* Add tests for batched forward passes
* Avoid unnecessary OOMs due to preallocation of GPU memory by XLA
* Auto-fix style
* Re-enable GPU memory preallocation but with mem fraction < 1/parallelism
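
The broadcasting issue arises because a batched 2D attention mask of shape (batch_size, seq_len) does not line up with per-head attention scores. A minimal sketch of the idea using jax.numpy; the helper name `expand_attention_mask`, the shapes, and the -1e9 bias value are illustrative assumptions, not the commit's actual code:

```python
import jax.numpy as jnp

def expand_attention_mask(attention_mask):
    # attention_mask: (batch_size, seq_len), 1.0 for real tokens, 0.0 for padding.
    # Insert broadcast axes so the mask aligns with attention scores of shape
    # (batch_size, num_heads, query_len, key_len).
    extended_mask = attention_mask[:, None, None, :]
    # Convert to an additive bias: 0 where attention is allowed, a large negative
    # value at padded positions so softmax assigns them ~0 probability.
    return (1.0 - extended_mask) * -1e9

# Example: a batch of two sequences, the second one padded.
mask = jnp.array([[1, 1, 1, 1],
                  [1, 1, 0, 0]], dtype=jnp.float32)
bias = expand_attention_mask(mask)  # shape (2, 1, 1, 4), broadcastable per batch element
```

Because this uses jax.numpy rather than ordinary numpy, the operation stays traceable and can run inside a jit-compiled forward pass.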
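For the GPU memory part, JAX reads the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable at startup. A hedged sketch of how a test setup might keep preallocation enabled while capping each process below 1/parallelism; the fraction 0.25 (assuming 4-way parallelism) is an example value, not taken from the commit:

```python
import os

# XLA preallocates a large chunk of GPU memory per process by default. When
# several processes share one GPU, capping each process to a memory fraction
# smaller than 1/parallelism avoids OOMs while keeping preallocation on.
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.25")

import jax  # import after setting the env var so XLA picks it up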