[Flax] adding support for batch norm layers (#21581)
* [flax] adding support for batch norm layers
* fixing bugs related to pt+flax integration
* cleanup, batchnorm support in sharded pt to flax
* support for batchnorm tests in pt+flax integration
* simplifying checking batch norm layer