Flax memory efficient attention (#2889)
* add use_memory_efficient params placeholder
* test
* add memory efficient attention jax
* newline
* forgot dot
* Rename use_memory_efficient
* Keep dtype last.
* Actually use key_chunk_size
* Rename symbol
* Apply style
* Pass `use_memory_efficient_attention` in `from_pretrained`
* Move JAX memory efficient attention to attention_flax.
* Simple test.
* style
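
For context, the idea behind chunking attention over keys (the `key_chunk_size` mentioned above) can be sketched as follows. This is an illustrative single-head example only, not the code added to `attention_flax` by this PR. The function name `chunked_attention`, its signature, and the requirement that the number of keys be divisible by `key_chunk_size` are all assumptions made for the sketch.

```python
import jax
import jax.numpy as jnp


def chunked_attention(query, key, value, key_chunk_size=4):
    """Hypothetical helper: softmax(q k^T / sqrt(d)) v, one key/value chunk at a time.

    Shapes: query (num_q, d), key (num_kv, d), value (num_kv, d_v).
    Assumes num_kv is divisible by key_chunk_size.
    """
    num_kv, dim = key.shape
    assert num_kv % key_chunk_size == 0, "sketch assumes evenly sized chunks"
    num_chunks = num_kv // key_chunk_size
    num_q, value_dim = query.shape[0], value.shape[-1]

    query = query / jnp.sqrt(dim)
    key_chunks = key.reshape(num_chunks, key_chunk_size, dim)
    value_chunks = value.reshape(num_chunks, key_chunk_size, value_dim)

    def step(carry, chunk):
        acc, denom, running_max = carry
        k, v = chunk
        scores = query @ k.T  # (num_q, key_chunk_size), never (num_q, num_kv)
        new_max = jnp.maximum(running_max, scores.max(axis=-1))
        # Rescale what was accumulated so far to the new running maximum.
        correction = jnp.exp(running_max - new_max)
        weights = jnp.exp(scores - new_max[:, None])
        acc = acc * correction[:, None] + weights @ v
        denom = denom * correction + weights.sum(axis=-1)
        return (acc, denom, new_max), None

    init = (
        jnp.zeros((num_q, value_dim), dtype=query.dtype),
        jnp.zeros((num_q,), dtype=query.dtype),
        jnp.full((num_q,), -jnp.inf, dtype=query.dtype),
    )
    (acc, denom, _), _ = jax.lax.scan(step, init, (key_chunks, value_chunks))
    return acc / denom[:, None]
```

Only one `(num_q, key_chunk_size)` block of scores is live at a time, so peak memory scales with the chunk size rather than with the full key length. The result should match ordinary softmax attention up to floating-point error.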
---------
Co-authored-by: muhammad_hanif <muhammad_hanif@sofcograha.co.id>
Co-authored-by: MuhHanif <48muhhanif@gmail.com>