Load sharded pt to flax (#18419)
* initial commit
* add small test
* add cross pt tf flag to test
* fix quality
* style
* update test with new repo
* fix failing test
* update
* fix wrong param ordering
* style
* update based on review
* update related to recent new caching mechanism
* quality
* Update based on review
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
* quality and style
* Update src/transformers/modeling_flax_utils.py
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>