Implement a proper shape checking rule for gather. (#4166)
* Implement a proper shape checking rule for gather.
The implementation is based on the corresponding shape inference
code in `tensorflow/compiler/xla/service/shape_inference.cc`. The
tests added in `tests/lax_test.py` are similarly mirroring the
corresponding tests in tensorflow, with slight adaptations for the
particular setting of JAX. Fixes google/jax#2826, and in principle
fixes google/jax#4154 and google/jax#3905.
* Extracted common functions for gather/scatter shape checking rules.