jax
8a2d4a01 - [export] Add and fix a test for exporting higher-order gradients with sharding

Commit
2 years ago
[export] Add and fix a test for exporting higher-order gradients with sharding There was a test for export with gradients, we changed the test to (a) export 2nd order gradient also, and (b) to export both with a mesh context and without a mesh context (using NamedSharding). This test currently fails, only in the case when we do NOT have a mesh context, as explained below: When exporting gradient functions, we first export the primal functions and we use the in/out-shardings to construct shardings of the gradient function. Since Exported shardings now contain only HloSharding objects, and to lower the gradient function we must use `pjit(vjp(f)).lower()`, we construct GSPMDSharding objects using the current devices and the HloSharding object from the Exported primal. However, these objects do not have the `_original_sharding` attribute. Later in `pjit._resource_typing_pjit` we attempt to `parse_flatten_op_sharding` using the mesh context (which is empty). This fails. This PR contains one workaround, to skip `parse_flatten_op_sharding` if the physical mesh of the `resource_env` is empty. Another, probably better solution, is to ensure that `resource_env` is `None` when then is no mesh context. That seemed reasonable, but currently the code returns an empty mesh from the resource_env if there is no mesh context. Changing this would have effects in more parts of the code, so I have not done it here, but it may be worth doing.
Author
Committer
Parents
Loading