Improve error message when passing an invalid dtype. (#3405)
I spotted this when debugging an issue like:
self.assertEqual(x_jax, x_tf, check_dtypes=True)
The fix here is of course to use `x_tf.numpy()`, but it was not clear where the
error was from originally.