[FlaxBert] Fix non-broadcastable attention mask for batched forward-passes (#8791)
* [FlaxBert] Fix non-broadcastable attention mask for batched forward-passes (see the broadcasting sketch below)
* [FlaxRoberta] Fix non-broadcastable attention mask
* Use jax.numpy instead of ordinary numpy (plain numpy is not jit-able; see the sketch below)
* Partially revert "Use jax.numpy ..."
* Add tests for batched forward passes (test sketch below)
* Avoid unnecessary OOMs due to preallocation of GPU memory by XLA
* Auto-fix style
* Re-enable GPU memory preallocation, but with a memory fraction below 1/parallelism (see the XLA configuration sketch below)
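
The mask fix, in sketch form. The shapes and variable names below are illustrative rather than the exact code from `modeling_flax_bert.py`, but they show both the broadcasting failure and the repair:

```python
import jax.numpy as jnp

batch, heads, seq = 2, 12, 7                  # illustrative sizes
scores = jnp.zeros((batch, heads, seq, seq))  # raw attention scores
attention_mask = jnp.ones((batch, seq))       # 1 = attend, 0 = mask out

# Before the fix, the (batch, seq) mask was combined directly with the
# (batch, heads, seq, seq) scores. Broadcasting aligns trailing axes, so
# the batch axis lines up against a seq axis: this only works by accident
# for batch == 1 and fails for batched inputs.
#
# The fix inserts singleton head and query axes, giving (batch, 1, 1, seq),
# which broadcasts over heads and query positions for any batch size.
mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
bias = jnp.where(mask > 0, 0.0, -1e10)        # additive attention bias
masked_scores = scores + bias                 # broadcasts cleanly
```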
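
Why the numpy → jax.numpy swap matters: ordinary numpy calls try to materialize JAX tracers and therefore fail inside `jax.jit`. A minimal reproduction, independent of the model code (the revert is presumably partial because plain numpy remains fine for values computed eagerly, outside any traced function):

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def with_numpy(mask):
    # np.expand_dims tries to convert the traced array to a concrete
    # numpy array, which raises a tracer-conversion error under jit.
    return np.expand_dims(mask, axis=0)

@jax.jit
def with_jax_numpy(mask):
    # jnp.expand_dims stays inside the traced computation, so it is jit-able.
    return jnp.expand_dims(mask, axis=0)

with_jax_numpy(jnp.ones((2, 7)))   # fine
# with_numpy(jnp.ones((2, 7)))     # raises under tracing
```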
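
A sketch of what the added batched-forward tests check; the model name, sentences, and assertion are illustrative, not the PR's exact test code:

```python
from transformers import BertTokenizerFast, FlaxBertModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
model = FlaxBertModel.from_pretrained("bert-base-cased")

# Two sentences of different lengths force padding and a real batch dimension.
sentences = ["Hello there.", "A somewhat longer second sentence."]
inputs = tokenizer(sentences, padding=True, return_tensors="np")

# Before the fix this call raised a broadcasting error for batch size > 1.
batched = model(**inputs)[0]
assert batched.shape[0] == len(sentences)
```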
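
The OOM commits configure XLA's GPU memory allocator. The environment variables below are real JAX/XLA knobs, but the worker count and fraction are illustrative; both variables must be set before `jax` is imported:

```python
import os

# XLA preallocates a fixed fraction of GPU memory per process. When several
# test workers share one GPU, the per-process fractions must sum to < 1,
# hence "mem fraction < 1/parallelism".
n_workers = 8  # illustrative number of parallel test processes
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{0.9 / n_workers:.2f}"

# The alternative tried first in this PR: disable preallocation entirely.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # noqa: E402 -- import after configuring the allocator
```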