Flax support for Stable Diffusion 2 (#1423)
* Flax: start adapting to Stable Diffusion 2
* More changes.
* attention_head_dim can be a tuple.
* Fix typos
* Add simple SD 2 integration test.
Slice values taken from my Ampere GPU.
* Add simple UNet integration tests for Flax.
Note that the expected values are taken from the PyTorch results. This
ensures the Flax and PyTorch versions are not too far off.
* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Typos and style
* Tests: verify jax is available.
* Style
* Make flake happy
* Remove typo.
* Simple Flax SD 2 pipeline tests.
* Import order
* Remove unused import.
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: @camenduru