vllm
[torch.compile] add logging for compilation time
#10941
Merged
youkaichao · 148 days ago
No description provided.
youkaichao compute compilation time
b9a03288
youkaichao fix one graph
c7b42481
github-actions · 148 days ago

๐Ÿ‘‹ Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI will run, covering a small and essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

๐Ÿš€

youkaichao update name
70e6b4e3
youkaichao · 148 days ago

example output (for the toy tests):

Compiling a graph for general shape takes 1.55 s

see https://buildkite.com/vllm/fastcheck/builds/9379#01939a08-6359-4659-9b3c-c07576a7efee/6465-8353

youkaichao requested a review from WoosukKwon 148 days ago
youkaichao marked this pull request as draft 148 days ago
youkaichao add context manager
3044d17d
youkaichao dirty impl
8c0ac0ba
youkaichao use api
0f6719b9
youkaichao fi
2bce8de7
youkaichao use timing at engine level
2b65379a
youkaichao engine level logging
81914b15
youkaichao remove the call in the worker
7ab3f8bd
youkaichao revert
10c88ee8
youkaichao fix logging
6f39058d
youkaichao fix symbolic only
70e35341
youkaichao fix format
fa36de59
youkaichao · 148 days ago (edited)

This PR adds three pieces of logging information:

  1. engine-level initialization time, e.g.

INFO 12-05 22:37:22 llm_engine.py:493] init engine (profile, create kv cache, warmup model) took 15.08 seconds
INFO 12-05 22:39:24 llm_engine.py:493] init engine (profile, create kv cache, warmup model) took 40.77 seconds
INFO 12-05 22:41:51 llm_engine.py:493] init engine (profile, create kv cache, warmup model) took 50.18 seconds

  2. graph compilation time for every shape (including the symbolic-shape compilation), e.g.

INFO 12-05 22:39:02 backends.py:55] Compiling a graph for general shape takes 14.73 s
INFO 12-05 22:41:51 backends.py:58] Compiling a graph for shape 1 takes 9.99 s

  3. aggregation of the numbers in 2, e.g.

INFO 12-05 22:39:04 monitor.py:13] graph compilation takes 14.73 s in total
INFO 12-05 22:41:51 monitor.py:13] graph compilation takes 24.758146286010742 s in total

How to read it:

  • The increase in (1) when running with torch.compile versus without it is the total cost of torch.compile.
  • (2) shows the compilation cost of each shape, so that users can select shapes according to their budget. Note that some compilation costs, such as Dynamo bytecode compilation and Triton compilation, are not counted here.
  • (3) simply aggregates the numbers from (2).
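The per-shape and aggregate timing described above can be sketched with a small helper. This is a hypothetical illustration only; the names (`CompilationMonitor`, `time_compile`) are invented and do not match the actual vLLM implementation in `vllm/compilation/backends.py` and `monitor.py`:

```python
import time

# Hypothetical sketch of the per-shape and aggregate timing described above;
# names are invented and do not match the actual vLLM code.
class CompilationMonitor:
    def __init__(self):
        self.total_time = 0.0  # aggregate across all compiled shapes (item 3)

    def time_compile(self, compile_fn, shape=None):
        start = time.time()
        result = compile_fn()  # stand-in for the actual graph compilation
        elapsed = time.time() - start
        self.total_time += elapsed
        label = "general shape" if shape is None else f"shape {shape}"
        print(f"Compiling a graph for {label} takes {elapsed:.2f} s")  # item 2
        return result

monitor = CompilationMonitor()
monitor.time_compile(lambda: sum(range(10)))           # symbolic-shape compile
monitor.time_compile(lambda: sum(range(10)), shape=1)  # shape-specific compile
print(f"graph compilation takes {monitor.total_time:.2f} s in total")  # item 3
```

The `%.2f`-style formatting keeps the aggregate line readable, which is also what the review below asks for.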

youkaichao marked this pull request as ready for review 148 days ago
youkaichao requested a review from robertgshaw2-redhat 148 days ago
youkaichao requested a review from njhill 148 days ago
youkaichao requested a review from ywang96 148 days ago
youkaichao requested a review from comaniac 148 days ago
youkaichao requested a review from alexm-redhat 148 days ago
youkaichao requested a review from zhuohan123 148 days ago
WoosukKwon approved these changes on 2024-12-06 · 148 days ago

Honestly, I don't have enough familiarity with this code for a proper review. However, the code looks OK to me. Please feel free to merge.

vllm/compilation/backends.py (conversation marked as resolved):

        example_inputs,
        additional_inductor_config,
        compilation_config: CompilationConfig,
        do_logging=False,
WoosukKwon · 148 days ago

nit:

Suggested change:
-        do_logging=False,
+        do_logging: bool = False,
vllm/compilation/monitor.py:

    def end_monitoring_torch_compile(compilation_config: CompilationConfig):
        if compilation_config.level == CompilationLevel.PIECEWISE:
            logger.info("graph compilation takes %.2f s in total",
                        compilation_config.compilation_time)
WoosukKwon · 148 days ago

Dumb question: why does it print only for piecewise CUDA graphs?

youkaichao · 148 days ago

CompilationLevel.PIECEWISE means piecewise compilation, not piecewise CUDA graphs; it is orthogonal to CUDA graphs.

vllm/compilation/backends.py:

    # we share the global graph pool among all the backends
    global_graph_pool = None

    compilation_start_time = 0.0
WoosukKwon · 148 days ago

Maybe None instead of 0.0?

WoosukKwon · 148 days ago

Just wondering: can we somehow make this more robust? Since the code touching this var is scattered into different places, I feel it's error prone...

youkaichao · 148 days ago

That's why I want to merge some functions in the worker/executor. However, given the current state of the code, I don't have the bandwidth to refactor those files.
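One way to address the robustness concern about the module-level `compilation_start_time` is to keep the start time on the stack via a context manager, so the begin/end bookkeeping cannot drift out of sync across files. A minimal sketch, assuming a caller-supplied list as the time sink (this is not the actual vLLM code):

```python
import contextlib
import time

# Hypothetical sketch, not the actual vLLM implementation: the start time
# lives in the context manager's frame rather than in a module-level global,
# so begin/end calls scattered across files cannot get mismatched.
@contextlib.contextmanager
def monitor_torch_compile(time_sink):
    start = time.time()
    try:
        yield
    finally:
        # record elapsed wall-clock time even if compilation raises
        time_sink.append(time.time() - start)

times = []
with monitor_torch_compile(times):
    _ = [x * x for x in range(1000)]  # stand-in for the compile step
print(f"graph compilation takes {sum(times):.2f} s in total")
```

The `finally` block guarantees the timing is recorded exactly once per `with` entry, which sidesteps the scattered-global problem raised above.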

WoosukKwon · 148 days ago

INFO 12-05 22:41:51 monitor.py:13] graph compilation takes 24.758146286010742 s in total

Please make sure to use %.2f for this log (if it hasn't been fixed yet).

youkaichao Update vllm/compilation/backends.py
d87ac207
youkaichao merge code
f5021216
youkaichao · 148 days ago

@WoosukKwon thanks for the review!

youkaichao enabled auto-merge (squash) 148 days ago
github-actions added the ready label
youkaichao merged b031a455 into main 148 days ago
youkaichao deleted the compilation_time branch 147 days ago
