a5f32f89 - training support for dynamo+torchxla integration (#88449)

training support for dynamo+torchxla integration (#88449)

We've already shown some promising perf results by integrating dynamo with torchxla for inference. To provide a consistent UX for training and for inference, in this PR we try to enable training for dynamo/torchxla. Training is trickier than inference, and we may not expect much perf gain, since:

1. In the training case, torchxla only generates a single combined graph for fwd/bwd/optimizer, while with the `torchxla_trace_once` bridge we added in dynamo, due to how AOT_Autograd works, we will generate 3 graphs: one for the forward, one for the backward, and one for the optimizer. XLA favors larger graphs where it can do more optimizations.
2. In the training case, tracing overhead can be overlapped with computation, so tracing overhead is not as big a deal for training as for inference. After all, training cares more about throughput while inference cares more about latency.
3. In the training case, people can increase the batch size to 'mitigate' the tracing overhead. Increasing the batch size does not change the tracing overhead, but it makes the tracing overhead 'per example' look smaller.

But we still want to add training support to dynamo/torchxla to make the work complete.

We added an '--iterations-per-run' argument to control how many iterations we do per measure/device sync. This is to understand the impact of item 2 above.

Results: with '--iterations-per-run' equal to 1, here are the perf numbers:

```
+--------------------------+--------------------+-------------------------+
| Model                    | XLA (trace once)   | XLA (trace everytime)   |
+==========================+====================+=========================+
| resnet18                 | 0.91               | 0.959                   |
+--------------------------+--------------------+-------------------------+
| resnet50                 | 0.917              | 0.932                   |
+--------------------------+--------------------+-------------------------+
| resnext50_32x4d          | 0.912              | 0.905                   |
+--------------------------+--------------------+-------------------------+
| alexnet                  | 1.038              | 0.974                   |
+--------------------------+--------------------+-------------------------+
| mobilenet_v2             | 0.881              | 0.835                   |
+--------------------------+--------------------+-------------------------+
| mnasnet1_0               | 0.903              | 0.931                   |
+--------------------------+--------------------+-------------------------+
| vgg16                    | 0.914              | 0.967                   |
+--------------------------+--------------------+-------------------------+
| BERT_pytorch             | 1.359              | 0.84                    |
+--------------------------+--------------------+-------------------------+
| timm_vision_transformer  | 1.288              | 0.893                   |
+--------------------------+--------------------+-------------------------+
| geomean                  | 1.0006             | 0.913794                |
+--------------------------+--------------------+-------------------------+
```

Overall it looks like graph breaks indeed cause a perf loss, but for BERT_pytorch and timm_vision_transformer we still see a perf gain. We need to do more experiments with larger '--iterations-per-run'.
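To make the effect of '--iterations-per-run' concrete, here is a rough, illustrative sketch of grouping several training iterations between device syncs. This is not the actual benchmark harness in benchmarks/dynamo; the `measure` helper and its arguments are made up for illustration, and it assumes torch_xla's `xla_model` API (`xm.mark_step`, `xm.wait_device_ops`):

```
# Illustrative sketch only: run several iterations per measure/device sync,
# similar in spirit to what the '--iterations-per-run' flag controls.
import time

import torch
import torch_xla.core.xla_model as xm


def measure(model, optimizer, loss_fn, batches, iterations_per_run=1):
    device = xm.xla_device()
    model = model.to(device).train()
    timings = []
    for start in range(0, len(batches), iterations_per_run):
        t0 = time.perf_counter()
        for inputs, targets in batches[start:start + iterations_per_run]:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
            xm.mark_step()       # cut the lazy-tensor graph for this iteration
        xm.wait_device_ops()     # device sync happens once per group of iterations
        timings.append(time.perf_counter() - t0)
    return timings
```

With a larger `iterations_per_run`, the per-iteration tracing overhead can overlap with device computation, which is the effect item 2 above describes.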
NOTE: In torchbench.py I added the following code as a few workarounds:

```
from myscripts import workaround # TODO will remove this line before landing
```

Here is the content of workaround.py:

```
import torch
from torch import nn
import os

# override max_pool2d with avg_pool2d
if os.environ.get("REPLACE_MAXPOOL", "0") == "1":
    torch.nn.MaxPool2d = torch.nn.AvgPool2d
```

It works around a few issues we found:

1. MaxPool2d does not work for training in dynamo/torchxla: https://github.com/pytorch/torchdynamo/issues/1837 . WIP fix from Brian in https://github.com/pytorch/pytorch/pull/90226 , https://github.com/pytorch/xla/pull/4276/files (WIP)
2. A recent change (this PR: https://github.com/pytorch/pytorch/pull/88697 ) in op decomposition causes batch_norm ops to fall back in torchxla. Fix from Jack in https://github.com/pytorch/xla/pull/4282#event-7969608134 . (Confirmed the fix after adding a Deduper to handle duplicated returns from the fx graph generated by AOTAutograd.)
3. We have an issue handling dropout because of a random-seed out-of-sync issue. Here is the fix: https://github.com/pytorch/xla/pull/4293 (confirmed the fix)

Example command:

```
REPLACE_MAXPOOL=1 USE_FAKE_TENSOR=0 GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --backend=aot_torchxla_trace_once --only vgg16
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88449
Approved by: https://github.com/wconstab, https://github.com/qihqi, https://github.com/malfet
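As a supplement to the example command above, here is a minimal, illustrative sketch (not code from this PR) of pointing a training step at the same `aot_torchxla_trace_once` dynamo backend that the benchmark uses; the toy model, optimizer, and `train_step` function are made up for illustration:

```
# Illustrative sketch only: compile a full training step (fwd/bwd/optimizer)
# with the aot_torchxla_trace_once backend. Because of AOTAutograd, this path
# produces separate forward/backward/optimizer graphs rather than torchxla's
# single combined lazy-tensor graph.
import torch
import torch._dynamo as dynamo
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 10).to(device).train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


@dynamo.optimize("aot_torchxla_trace_once")
def train_step(inputs, targets):
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss


for _ in range(3):
    inputs = torch.randn(32, 128, device=device)
    targets = torch.randint(0, 10, (32,), device=device)
    loss = train_step(inputs, targets)
    xm.mark_step()  # flush pending device work; the benchmark script does its own syncing
```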