dynamo/torchxla integration: trace on xla rather than eager (#88904)
In #87741 we added the inference support for dynamo/torchxla integration. Later on in #88449 we attempt to add the training support. That attempt is not smooth because
- we try 2 things together
1. let dynamo trace the model on xla rather than eager
2. enable training
- It turns out neither of these two tasks are trivial enough.
Furthermore, item 2 (enable training) depends on item 1 (tracing on xla). We enable training via AOTAutograd. AOTAutograd lift all model parameters/buffers as graph inputs. Without item 1 being done, we would need copy all graph inputs (including model parameters/buffers) from eager device to xla devices. That hurts performance a lot. Have a cache to map eager parameter to XLA parameter does not solve the problem since the update on either will not sync automatically to the other. They will easily go out of sync.
This PR let dynamo trace the model on XLA rather than eager. This is a preparation step to enabling training.
Also, tracing on XLA makes the data movement more efficient. We see 1.5x geomean speedup compared to previous 1.38x.
```
+-------------------------+--------------------+-------------------------+
| Model | XLA (trace once) | XLA (trace everytime) |
+=========================+====================+=========================+
| resnet18 | 1.38 | 1.008 |
+-------------------------+--------------------+-------------------------+
| resnet50 | 1.227 | 0.998 |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d | 1.544 | 1.008 |
+-------------------------+--------------------+-------------------------+
| alexnet | 1.085 | 1.045 |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2 | 2.028 | 1.013 |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0 | 1.516 | 0.995 |
+-------------------------+--------------------+-------------------------+
| squeezenet1_1 | 0.868 | 1.01 |
+-------------------------+--------------------+-------------------------+
| vgg16 | 1.099 | 1.008 |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch | 3.26 | 1.027 |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer | 2.182 | 1.015 |
+-------------------------+--------------------+-------------------------+
| geomean | 1.50389 | 1.01261 |
+-------------------------+--------------------+-------------------------+
```
Example command
```
GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --only resnet18 --backend=torchxla_trace_once
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88904
Approved by: https://github.com/wconstab, https://github.com/JackCaoG, https://github.com/jansel