jax
1a87fd3b - Implement a proper shape checking rule for gather. (#4166)

Commit
5 years ago
Implement a proper shape checking rule for gather. (#4166) * Implement a proper shape checking rule for gather. The implementation is based on the corresponding shape inference code in `tensorflow/compiler/xla/service/shape_inference.cc`. The tests added in `tests/lax_test.py` are similarly mirroring the corresponding tests in tensorflow, with slight adaptations for the particular setting of JAX. Fixes google/jax#2826, and in principle fixes google/jax#4154 and google/jax#3905. * Extracted common functions for gather/scatter shape checking rules.
Author
Parents
Loading