[export] Fix device assignment error for grad of exported.
Currently, the export code uses a manufactured device assignment
for exporting the VJP function. We should use instead the same
device assigment that was used when exporting the primal function.
This PR fixes that for the case when the export is done through
the direct use of `jax.experimental.export`, and leaves as future
work the case when the use is from `jax2tf`. We add a disabled
tests for the latter case.
Bug: #21314