xla::PjRtDevice* pjrt_device =
    StringToPjRtDevice(instance.compilation_device);
can you add this env var to https://github.com/pytorch/xla/blob/master/configuration.yaml?
I'd add a comment here saying this variable is experimental; hence it's set to false. When we have a stable state, we should set it to true.
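For the configuration.yaml request above, the new entry might look roughly like the sketch below. The field names are an assumption based on the pattern of existing entries in that file and may differ; follow the file's actual schema.

```yaml
# Hypothetical sketch of a configuration.yaml entry for the new flag;
# field names mirror existing entries and may not match exactly.
variables:
  XLA_STABLEHLO_COMPILE:
    description:
      - Experimental. When true, pass StableHLO (converted from HLO) to the
        PjRt client for compilation instead of HLO. Defaults to false until
        the feature is stable.
    type: bool
    default_value: false
```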
Let's run a benchmark with / without this PR and determine performance. We can start with resnet and mnist.
cc @will-cromar, as the change is made on PjrtComputationClient.
LGTM. Thanks!
I don't have any concerns about performance in this PR, because you can simply disable XLA_STABLEHLO_COMPILE.
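To illustrate the point above: a boolean env-var flag like this is cheap to gate on, so the StableHLO path can be switched off entirely. A minimal sketch of how such a flag is typically read (the actual torch_xla helper may differ):

```python
import os

# Sketch of reading a boolean env-var flag such as XLA_STABLEHLO_COMPILE;
# the real torch_xla configuration helper may behave differently.
def stablehlo_compile_enabled() -> bool:
    return os.environ.get("XLA_STABLEHLO_COMPILE", "0") == "1"

if stablehlo_compile_enabled():
    print("compiling via StableHLO")
else:
    print("compiling via HLO")
```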
> Let's run a benchmark with / without this PR and determine performance. We can start with resnet and mnist.
Sure, will do that on ResNet and MNIST and share the results.
Measured ResNet18 compile and execution time on TPU v4. StableHLO takes more time to compile as it needs to be converted back to HLO at some point in XLA PjRt Client. Execution time is similar.
Script: https://gist.github.com/lsy323/057348110ba439b58861400d35ef5ff0
Command to run:
Compile with StableHLO: PJRT_DEVICE=TPU XLA_STABLEHLO_COMPILE=1 python test_stablehlo_latency.py
Compile with HLO: PJRT_DEVICE=TPU XLA_STABLEHLO_COMPILE=0 python test_stablehlo_latency.py
Output (StableHLO):
root@t1v-n-617b7043-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=TPU XLA_STABLEHLO_COMPILE=1 python test/stablehlo/test_stablehlo_latency.py
tensor([[-3.0977, -0.9105, -2.2717, ..., 2.0278, -4.0764, -3.4962],
[-3.0307, -0.8453, -2.2413, ..., 2.0563, -4.0359, -3.5483],
[-3.0857, -0.8055, -2.3006, ..., 2.0095, -4.0301, -3.5824],
[-3.0294, -0.9413, -2.2699, ..., 2.1221, -4.1136, -3.6157]],
device='xla:0', dtype=torch.float64, grad_fn=<AddmmBackward0>)
Metric: CompileTime
TotalSamples: 1
Accumulator: 04s760ms645.719us
Percentiles: 1%=04s760ms645.719us; 5%=04s760ms645.719us; 10%=04s760ms645.719us; 20%=04s760ms645.719us; 50%=04s760ms645.719us; 80%=04s760ms645.719us; 90%=04s760ms645.719us; 95%=04s760ms645.719us; 99%=04s760ms645.719us
Metric: ExecuteTime
TotalSamples: 1
Accumulator: 006ms434.110us
Percentiles: 1%=006ms434.110us; 5%=006ms434.110us; 10%=006ms434.110us; 20%=006ms434.110us; 50%=006ms434.110us; 80%=006ms434.110us; 90%=006ms434.110us; 95%=006ms434.110us; 99%=006ms434.110us
Counter: StableHloCompile
Value: 1
Output (HLO):
root@t1v-n-617b7043-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=TPU XLA_STABLEHLO_COMPILE=0 python test/stablehlo/test_stablehlo_latency.py
tensor([[ 1.5539, -1.4548, -1.3284, ..., -0.4333, -4.3026, -3.7028],
[ 1.7148, -1.4618, -1.4548, ..., -0.6571, -4.3884, -3.7598],
[ 1.6167, -1.4881, -1.3145, ..., -0.6026, -4.2327, -3.7895],
[ 1.6563, -1.4699, -1.3832, ..., -0.5625, -4.3638, -3.7374]],
device='xla:0', dtype=torch.float64, grad_fn=<AddmmBackward0>)
Metric: CompileTime
TotalSamples: 1
Accumulator: 04s696ms979.632us
Percentiles: 1%=04s696ms979.632us; 5%=04s696ms979.632us; 10%=04s696ms979.632us; 20%=04s696ms979.632us; 50%=04s696ms979.632us; 80%=04s696ms979.632us; 90%=04s696ms979.632us; 95%=04s696ms979.632us; 99%=04s696ms979.632us
Metric: ExecuteTime
TotalSamples: 1
Accumulator: 006ms091.800us
Percentiles: 1%=006ms091.800us; 5%=006ms091.800us; 10%=006ms091.800us; 20%=006ms091.800us; 50%=006ms091.800us; 80%=006ms091.800us; 90%=006ms091.800us; 95%=006ms091.800us; 99%=006ms091.800us
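To quantify the compile-time difference from the two reports above, the accumulator strings can be parsed into microseconds. The `XXsYYYmsZZZ.ZZZus` format is assumed from the metrics report output; with these two samples the StableHLO path compiles about 1.4% slower, and execution time is essentially unchanged.

```python
import re

# Parse torch_xla metric durations like '04s760ms645.719us' into microseconds.
# The unit layout is assumed from the metrics report printed above.
def parse_metric_time_us(s: str) -> float:
    units = {"s": 1e6, "ms": 1e3, "us": 1.0}
    return sum(float(v) * units[u] for v, u in re.findall(r"([\d.]+)(us|ms|s)", s))

stablehlo_us = parse_metric_time_us("04s760ms645.719us")  # StableHLO CompileTime
hlo_us = parse_metric_time_us("04s696ms979.632us")        # HLO CompileTime
overhead_pct = (stablehlo_us - hlo_us) / hlo_us * 100     # ~1.4% in this run
```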
cc @miladm
# Run ResNet on XLA device.
device = xm.xla_device()
# materalize the fake data for test purpose
xm.mark_step()
I wonder why we need to do mark_step here. The comment says "materalize the fake data for test purpose", but we don't have an HLO graph at this point.
I was referencing the code in the dynamo test. I think the purpose is to clear the graph from previous tests.
Benchmarked on a Transformer (torch.nn.Transformer); no degradation observed in execution time.
script
Compile with StableHLO
Metric: CompileTime
TotalSamples: 1
Accumulator: 13s679ms954.216us
Percentiles: 1%=13s679ms954.216us; 5%=13s679ms954.216us; 10%=13s679ms954.216us; 20%=13s679ms954.216us; 50%=13s679ms954.216us; 80%=13s679ms954.216us; 90%=13s679ms954.216us; 95%=13s679ms954.216us; 99%=13s679ms954.216us
Metric: ExecuteTime
TotalSamples: 20
Accumulator: 112ms122.061us
ValueRate: 002ms612.975us / second
Rate: 0.287718 / second
Percentiles: 1%=004ms238.920us; 5%=004ms327.740us; 10%=004ms482.730us; 20%=005ms692.209us; 50%=005ms805.449us; 80%=005ms962.460us; 90%=006ms530.290us; 95%=021ms244.279us; 99%=021ms244.279us
Counter: StableHloCompile
Value: 1
Compile with HLO
Metric: CompileTime
TotalSamples: 1
Accumulator: 11s715ms884.137us
Percentiles: 1%=11s715ms884.137us; 5%=11s715ms884.137us; 10%=11s715ms884.137us; 20%=11s715ms884.137us; 50%=11s715ms884.137us; 80%=11s715ms884.137us; 90%=11s715ms884.137us; 95%=11s715ms884.137us; 99%=11s715ms884.137us
Metric: ExecuteTime
TotalSamples: 20
Accumulator: 129ms986.975us
ValueRate: 002ms878.180us / second
Rate: 0.29122 / second
Percentiles: 1%=005ms857.130us; 5%=005ms859.580us; 10%=005ms866.300us; 20%=005ms946.980us; 50%=005ms125.240us; 80%=005ms379.290us; 90%=006ms699.240us; 95%=031ms177.398us; 99%=031ms177.398us
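The per-step execution time from the Transformer reports above is just the Accumulator divided by TotalSamples (20 steps); the totals below are transcribed from the output by hand. In this run the StableHLO path actually executes slightly faster per step, consistent with no degradation.

```python
# Transformer ExecuteTime totals from the metrics reports above, in us.
stablehlo_total_us = 112_122.061   # 112ms122.061us over 20 samples
hlo_total_us = 129_986.975         # 129ms986.975us over 20 samples

stablehlo_avg_ms = stablehlo_total_us / 20 / 1000  # ~5.6 ms per step
hlo_avg_ms = hlo_total_us / 20 / 1000              # ~6.5 ms per step
```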
@JackCaoG I think we can merge this one now.
The StableHLO compilation flag is for experimental purposes, to make sure the generated StableHLO can be consumed by the XLA PjRt client.
- Add env var XLA_STABLEHLO_COMPILE to enable PjRt compilation with StableHLO. The StableHLO is generated from the HLO->StableHLO conversion function.
- Test command: XLA_STABLEHLO_COMPILE=1 python test/stablehlo/test_stablehlo_compile.py
cc @miladm