jax
6deeee27 - [export] Fix device assignment error for grad of exported.

Commit
1 year ago
[export] Fix device assignment error for grad of exported. Currently, the export code uses a manufactured device assignment for exporting the VJP function. We should use instead the same device assigment that was used when exporting the primal function. This PR fixes that for the case when the export is done through the direct use of `jax.experimental.export`, and leaves as future work the case when the use is from `jax2tf`. We add a disabled tests for the latter case. Bug: #21314
Author
Committer
Parents
Loading