xla
Enable PjRt Client Compilation with StableHLO
#5233
Merged

JackCaoG merged 6 commits into master from stablehlo-compile
lsy323
lsy323 1 year ago (edited 1 year ago)

The StableHLO compilation flag is experimental. Its purpose is to verify that the generated StableHLO can be consumed by the XLA PjRt client.

  • Users can set the environment variable XLA_STABLEHLO_COMPILE to enable PjRt compilation with StableHLO. The StableHLO is generated by the HLO->StableHLO conversion function.
  • Added a metrics counter for StableHLO compilation.
  • Added a test: XLA_STABLEHLO_COMPILE=1 python test/stablehlo/test_stablehlo_compile.py
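
To illustrate the on/off convention of the flag, here is a minimal Python sketch. It is not the code in this PR: the flag is read inside the runtime, and `stablehlo_compile_enabled` is a hypothetical helper. It assumes that `1` enables the feature and that unset, empty, `0`, or `false` leave it disabled, matching the commands used in this thread.

```python
import os


def stablehlo_compile_enabled(environ=None):
    """Return True when XLA_STABLEHLO_COMPILE is set to a truthy value.

    Hypothetical helper for illustration only. Unset, empty, "0" and
    "false" are treated as disabled (an assumption, not the runtime's parser).
    """
    env = os.environ if environ is None else environ
    value = env.get("XLA_STABLEHLO_COMPILE", "0").strip().lower()
    return value not in ("", "0", "false")


if __name__ == "__main__":
    # Reports the state of the flag in the current process environment.
    print("StableHLO compile enabled:", stablehlo_compile_enabled())
```

Passing a plain dict instead of `os.environ` makes the helper easy to check without touching the real environment.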

cc @miladm

lsy323 assigned lsy323 1 year ago
lsy323 added stablehlo
JackCaoG
JackCaoG commented on 2023-06-22
torch_xla/csrc/runtime/pjrt_computation_client.cc
xla::PjRtDevice* pjrt_device =
    StringToPjRtDevice(instance.compilation_device);
JackCaoG 1 year ago
miladm 1 year ago

I'd add a comment here saying that this variable is experimental and is therefore set to false by default. Once the feature is stable, we should set it to true.

miladm
miladm commented on 2023-06-22
miladm 1 year ago 👍 1

Let's run a benchmark with / without this PR and determine performance. We can start with resnet and mnist.

lsy323
lsy323 1 year ago (edited 1 year ago)

cc @will-cromar, since this change touches PjRtComputationClient.

will-cromar
will-cromar approved these changes on 2023-06-22
will-cromar 1 year ago 👍 1

LGTM. Thanks!

I don't have any performance concerns about this PR, because XLA_STABLEHLO_COMPILE can simply be left disabled.

lsy323
lsy323 1 year ago

> Let's run a benchmark with / without this PR and determine performance. We can start with resnet and mnist.

Sure, will do that on ResNet and MNIST and share the results.

lsy323
lsy323 1 year ago

Measured ResNet18 compile and execution time on TPU v4. Compiling from StableHLO takes slightly longer (4.76 s vs. 4.70 s), because the StableHLO has to be converted back to HLO at some point inside the XLA PjRt client. Execution time is similar (6.43 ms vs. 6.09 ms).

Script: https://gist.github.com/lsy323/057348110ba439b58861400d35ef5ff0
Command to run:
Compile with StableHLO: PJRT_DEVICE=TPU XLA_STABLEHLO_COMPILE=1 python test_stablehlo_latency.py
Compile with HLO: PJRT_DEVICE=TPU XLA_STABLEHLO_COMPILE=0 python test_stablehlo_latency.py

Output (StableHLO):

root@t1v-n-617b7043-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=TPU XLA_STABLEHLO_COMPILE=1 python test/stablehlo/test_stablehlo_latency.py
tensor([[-3.0977, -0.9105, -2.2717,  ...,  2.0278, -4.0764, -3.4962],
        [-3.0307, -0.8453, -2.2413,  ...,  2.0563, -4.0359, -3.5483],
        [-3.0857, -0.8055, -2.3006,  ...,  2.0095, -4.0301, -3.5824],
        [-3.0294, -0.9413, -2.2699,  ...,  2.1221, -4.1136, -3.6157]],
       device='xla:0', dtype=torch.float64, grad_fn=<AddmmBackward0>)
Metric: CompileTime
  TotalSamples: 1
  Accumulator: 04s760ms645.719us
  Percentiles: 1%=04s760ms645.719us; 5%=04s760ms645.719us; 10%=04s760ms645.719us; 20%=04s760ms645.719us; 50%=04s760ms645.719us; 80%=04s760ms645.719us; 90%=04s760ms645.719us; 95%=04s760ms645.719us; 99%=04s760ms645.719us
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 006ms434.110us
  Percentiles: 1%=006ms434.110us; 5%=006ms434.110us; 10%=006ms434.110us; 20%=006ms434.110us; 50%=006ms434.110us; 80%=006ms434.110us; 90%=006ms434.110us; 95%=006ms434.110us; 99%=006ms434.110us
Counter: StableHloCompile
  Value: 1

Output (HLO):

root@t1v-n-617b7043-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=TPU XLA_STABLEHLO_COMPILE=0 python test/stablehlo/test_stablehlo_latency.py
tensor([[ 1.5539, -1.4548, -1.3284,  ..., -0.4333, -4.3026, -3.7028],
        [ 1.7148, -1.4618, -1.4548,  ..., -0.6571, -4.3884, -3.7598],
        [ 1.6167, -1.4881, -1.3145,  ..., -0.6026, -4.2327, -3.7895],
        [ 1.6563, -1.4699, -1.3832,  ..., -0.5625, -4.3638, -3.7374]],
       device='xla:0', dtype=torch.float64, grad_fn=<AddmmBackward0>)
Metric: CompileTime
  TotalSamples: 1
  Accumulator: 04s696ms979.632us
  Percentiles: 1%=04s696ms979.632us; 5%=04s696ms979.632us; 10%=04s696ms979.632us; 20%=04s696ms979.632us; 50%=04s696ms979.632us; 80%=04s696ms979.632us; 90%=04s696ms979.632us; 95%=04s696ms979.632us; 99%=04s696ms979.632us
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 006ms091.800us
  Percentiles: 1%=006ms091.800us; 5%=006ms091.800us; 10%=006ms091.800us; 20%=006ms091.800us; 50%=006ms091.800us; 80%=006ms091.800us; 90%=006ms091.800us; 95%=006ms091.800us; 99%=006ms091.800us
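
To compare the two runs numerically, the duration strings in the metrics report can be converted to seconds. The sketch below is a minimal illustration, not part of the PR: `parse_duration` is a hypothetical helper, and it assumes only the `s`, `ms`, and `us` units that appear in the output above.

```python
import re

# Seconds per unit. Only the units seen in the report above are handled.
_UNIT_SECONDS = {"s": 1.0, "ms": 1e-3, "us": 1e-6}
# "ms" and "us" come before "s" so the longer unit names are tried first.
_TOKEN = re.compile(r"(\d+(?:\.\d+)?)(ms|us|s)")


def parse_duration(text):
    """Convert a string such as '04s760ms645.719us' to seconds."""
    tokens = _TOKEN.findall(text)
    if not tokens:
        raise ValueError(f"unrecognized duration: {text!r}")
    return sum(float(number) * _UNIT_SECONDS[unit] for number, unit in tokens)


# CompileTime accumulators copied from the two ResNet18 runs above.
stablehlo_compile = parse_duration("04s760ms645.719us")
hlo_compile = parse_duration("04s696ms979.632us")

overhead = stablehlo_compile - hlo_compile
print(f"compile overhead: {overhead * 1e3:.1f} ms ({overhead / hlo_compile:.1%})")
```

For the ResNet18 runs above this works out to roughly 64 ms, about 1.4% of the HLO compile time. Each metric has a single sample, so a gap of this size may be within run-to-run noise.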

cc @miladm

vanbasten23
vanbasten23 commented on 2023-06-23
test/stablehlo/test_stablehlo_compile.py
# Run ResNet on XLA device.
device = xm.xla_device()
# materalize the fake data for test purpose
xm.mark_step()
vanbasten23 1 year ago

I wonder why we need to call mark_step here. The comment says "materalize the fake data for test purpose", but we don't have an HLO graph at this point.

lsy323 1 year ago 👍 1

I was following the code in the dynamo test. I think the purpose is to clear the graph left over from previous tests.

lsy323 marked this pull request as ready for review 1 year ago
lsy323 force pushed from d24dce8a to 4721aa6d 1 year ago
lsy323 Enable xla PjRt client compilation with StableHLO
b63ab577
lsy323 add XLA_STABLEHLO_COMPILE to configuration.yaml
5af47b0c
lsy323 force pushed from 4721aa6d to 5af47b0c 1 year ago
lsy323 fix merge conflict
b79d5aa1
lsy323 dummy commit to trigger ci
f7aec233
lsy323 Revert "dummy commit to trigger ci"
aaaa4b39
lsy323 Merge branch 'master' into stablehlo-compile
6c357a6b
lsy323
lsy323 1 year ago

Benchmarked on a Transformer (torch.nn.Transformer); no degradation observed in execution time.
script
Compile with StableHLO

Metric: CompileTime
  TotalSamples: 1
  Accumulator: 13s679ms954.216us
  Percentiles: 1%=13s679ms954.216us; 5%=13s679ms954.216us; 10%=13s679ms954.216us; 20%=13s679ms954.216us; 50%=13s679ms954.216us; 80%=13s679ms954.216us; 90%=13s679ms954.216us; 95%=13s679ms954.216us; 99%=13s679ms954.216us
Metric: ExecuteTime
  TotalSamples: 20
  Accumulator: 112ms122.061us
  ValueRate: 002ms612.975us / second
  Rate: 0.287718 / second
  Percentiles: 1%=004ms238.920us; 5%=004ms327.740us; 10%=004ms482.730us; 20%=005ms692.209us; 50%=005ms805.449us; 80%=005ms962.460us; 90%=006ms530.290us; 95%=021ms244.279us; 99%=021ms244.279us
Counter: StableHloCompile
  Value: 1

Compile with HLO

Metric: CompileTime
  TotalSamples: 1
  Accumulator: 11s715ms884.137us
  Percentiles: 1%=11s715ms884.137us; 5%=11s715ms884.137us; 10%=11s715ms884.137us; 20%=11s715ms884.137us; 50%=11s715ms884.137us; 80%=11s715ms884.137us; 90%=11s715ms884.137us; 95%=11s715ms884.137us; 99%=11s715ms884.137us
Metric: ExecuteTime
  TotalSamples: 20
  Accumulator: 129ms986.975us
  ValueRate: 002ms878.180us / second
  Rate: 0.29122 / second
  Percentiles: 1%=005ms857.130us; 5%=005ms859.580us; 10%=005ms866.300us; 20%=005ms946.980us; 50%=005ms125.240us; 80%=005ms379.290us; 90%=006ms699.240us; 95%=031ms177.398us; 99%=031ms177.398us
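
Because each ExecuteTime metric here covers 20 samples, a mean per-execution time can be derived by dividing the Accumulator by TotalSamples. The sketch below does that arithmetic for the two Transformer runs. `mean_ms` is a hypothetical helper, and the values are copied by hand from the report above.

```python
def mean_ms(total_ms, total_us, samples):
    """Mean per-sample time in milliseconds.

    The accumulator is passed as its millisecond and microsecond parts,
    e.g. '112ms122.061us' becomes (112, 122.061).
    """
    return (total_ms + total_us / 1000.0) / samples


# ExecuteTime accumulators and sample counts from the Transformer runs.
stablehlo_mean = mean_ms(112, 122.061, 20)  # Accumulator: 112ms122.061us
hlo_mean = mean_ms(129, 986.975, 20)        # Accumulator: 129ms986.975us

print(f"StableHLO: {stablehlo_mean:.3f} ms per execution")
print(f"HLO:       {hlo_mean:.3f} ms per execution")
```

This gives about 5.6 ms per execution with StableHLO and about 6.5 ms with HLO, which is consistent with the observation that execution time does not degrade.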
lsy323
lsy323 1 year ago

@JackCaoG I think we can merge this one now.

JackCaoG
JackCaoG approved these changes on 2023-06-28
JackCaoG merged 905be9d5 into master 1 year ago
lsy323 deleted the stablehlo-compile branch 1 year ago
