[export] Add and fix a test for exporting higher-order gradients with sharding
There was a test for export with gradients, we changed the test to
(a) export 2nd order gradient also, and (b) to export both with
a mesh context and without a mesh context (using NamedSharding).
This test currently fails, only in the case when we do NOT have a
mesh context, as explained below:
When exporting gradient functions, we first export the primal functions
and we use the in/out-shardings to construct shardings of the gradient
function. Since Exported shardings now contain only HloSharding objects,
and to lower the gradient function we must use `pjit(vjp(f)).lower()`, we
construct GSPMDSharding objects using the current devices and the HloSharding
object from the Exported primal.
However, these objects do not have the `_original_sharding` attribute.
Later in `pjit._resource_typing_pjit` we attempt to `parse_flatten_op_sharding`
using the mesh context (which is empty). This fails.
This PR contains one workaround, to skip `parse_flatten_op_sharding` if
the physical mesh of the `resource_env` is empty.
Another, probably better solution, is to ensure that `resource_env` is
`None` when then is no mesh context. That seemed reasonable, but currently
the code returns an empty mesh from the resource_env if there is no
mesh context. Changing this would have effects in more parts of the code,
so I have not done it here, but it may be worth doing.