torchdynamo and xla integration (#87741)
# Motivation
- torchdynamo and torchxla use different strategies to make graph capture sound: the former relies on guards, the latter on retracing
- the guard system has low overhead, whereas torchxla's tracing overhead is high
The main idea is to leverage the guard system in torchdynamo to avoid retracing in torchxla so that
- we can integrate torchdynamo with XLA
- we reduce or even completely avoid the tracing overhead of torchxla
# Technique details
## XLA baseline
We found that different frameworks do not generate numerically identical results for the SAME model with the SAME input. By default, torchdynamo uses eager as the baseline, so the baseline model runs with PyTorch. Comparing a model running on XLA against this baseline makes correctness hard to check. To make the comparison easier, we add a flag `--use-xla-baseline`. When it is enabled, the baseline runs on XLA.
## New dynamo backends added
We add 2 new dynamo backends, torchxla_trivial and torchxla_trace_once, to control the optimization targets.
torchxla_trivial simply moves the inputs and model parameters to XLA and runs the model on XLA. There is tracing overhead for each run. We expect its result to be mostly neutral compared to the XLA baseline.
torchxla_trace_once traces only once, at AOT compile time. Here are the steps:
1. dynamo captures the guards and the subgraph
2. the torchxla_trace_once backend traces the graph with torchxla, lowers the graph, and records a hash of the graph for later lookup
3. at inference time, the hash is used directly to look up the optimized graph and run it.
# Limitations
We cannot handle LTC/torchxla fallback right now. If an op lacks an LTC kernel, we raise an exception, which results in dynamo fallback (or trying another compiler). People have brainstormed the idea of graph breaking and stitching the subgraphs together, but it may be easier to add the missing LTC kernels for those models.
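As a rough illustration of the current behavior (raise on a missing kernel, then fall back), here is a small self-contained sketch. The names `SUPPORTED_OPS`, `MissingKernelError`, `compile_for_xla`, and `compile_with_fallback` are hypothetical and do not correspond to real dynamo or torchxla functions; the graph is again modeled as a list of op names.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical set of ops that have a kernel in this toy model.
SUPPORTED_OPS: Dict[str, Callable[[int], int]] = {
    "add1": lambda x: x + 1,
    "double": lambda x: x * 2,
}


class MissingKernelError(RuntimeError):
    """Raised when the toy graph contains an op without a kernel."""


def compile_for_xla(ops: List[str]) -> Callable[[int], int]:
    """Refuse to compile the whole graph if any op lacks a kernel."""
    missing = [op for op in ops if op not in SUPPORTED_OPS]
    if missing:
        raise MissingKernelError(f"no kernel for: {missing}")

    def run(x: int) -> int:
        for op in ops:
            x = SUPPORTED_OPS[op](x)
        return x

    return run


def compile_with_fallback(
    ops: List[str], eager_fn: Callable[[int], int]
) -> Tuple[Callable[[int], int], str]:
    """Try the toy compiler first; on a missing kernel, fall back to eager."""
    try:
        return compile_for_xla(ops), "xla"
    except MissingKernelError:
        return eager_fn, "eager"


ok_fn, ok_path = compile_with_fallback(["add1", "double"], lambda x: (x + 1) * 2)
fb_fn, fb_path = compile_with_fallback(["add1", "sqrt"], lambda x: x + 1)
```

Note that the fallback here is all-or-nothing for the graph, which matches the limitation described above: no graph breaking or subgraph stitching is attempted.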
# Results
The models we tested are those that do not cause LTC fallback. We ran the tests on **GPU**. We see a **1.38x** geomean speedup for torchxla_trace_once, and torchxla_trivial is mostly neutral, as expected.
```
| Model | XLA (trace once) | XLA (trace everytime) |
+=========================+====================+=========================+
| resnet18 | 1.346 | 1.045 |
+-------------------------+--------------------+-------------------------+
| resnet50 | 1.153 | 1.007 |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d | 1.381 | 1.039 |
+-------------------------+--------------------+-------------------------+
| alexnet | 1.045 | 1.018 |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2 | 1.562 | 1.021 |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0 | 1.303 | 1.069 |
+-------------------------+--------------------+-------------------------+
| squeezenet1_1 | 1.278 | 1.025 |
+-------------------------+--------------------+-------------------------+
| vgg16 | 1.076 | 1.008 |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch | 2.224 | 0.978 |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer | 1.81 | 1.025 |
+-------------------------+--------------------+-------------------------+
| geomean | 1.38101 | 1.02324 |
+-------------------------+--------------------+-------------------------+
```
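For reference, the geomean row can be reproduced from the per-model speedups in the table. The snippet below is a small sketch that assumes the usual definition of geometric mean (exponential of the mean of the logs); the variable names are ours, and the values are copied from the table.

```python
import math

# Per-model speedups copied from the table above, in row order.
trace_once = [1.346, 1.153, 1.381, 1.045, 1.562, 1.303, 1.278, 1.076, 2.224, 1.81]
trace_every_time = [1.045, 1.007, 1.039, 1.018, 1.021, 1.069, 1.025, 1.008, 0.978, 1.025]


def geomean(values):
    """Geometric mean computed as exp of the arithmetic mean of the logs."""
    return math.exp(sum(math.log(v) for v in values) / len(values))


geomean_trace_once = geomean(trace_once)
geomean_trace_every_time = geomean(trace_every_time)
print(f"trace once: {geomean_trace_once:.5f}")
print(f"trace every time: {geomean_trace_every_time:.5f}")
```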
The speedup is similar to what we saw in previous work on LTC's TorchScript backend (a 1.40x geomean speedup there):
https://docs.google.com/presentation/d/1G09X8v41u_cLKLtSdf7v6R8G19-iZTPcW_VAdOnvYBI/edit#slide=id.g11bf989cb6b_1_5
# Next steps
- Use AOT autograd to enable training
- Share results on XLA devices
- Do more extensive tests on torchbench models
Example command:
```
GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --use-xla-baseline --only resnet18 --backend=torchxla_trace_once
```
Thanks to @JackCaoG from the torchxla team for helping debug various perf issues and merging the torchxla PR! That's super critical for us to get the results above. torchxla side PR: https://github.com/pytorch/xla/pull/4119
topic: not user facing
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @jansel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87741
Approved by: https://github.com/wconstab