Add population_count primitive to lax (#2753)
* add population_count primitive (needs new jaxlib)
fixes #2263
* Add popcount docs
* Add population_count to lax_reference
* Use int prng (since we're only testing uints)
Co-authored-by: Matthew Johnson <mattjj@google.com>