vllm
[V1] Support VLMs with fine-grained scheduling
#9871
Merged


WoosukKwon merged 3 commits into main from v1-vlm-sched
WoosukKwon
WoosukKwon189 days ago (edited 181 days ago)👍 5

This PR implements basic vision language model (VLM) support in V1.

Motivation

Multi-modal inputs are difficult to deal with because they often have complex (or non-trivial) dependencies. For example, the model can take a prompt with interleaved text and images like:

[Screenshot 2024-11-07 at 10 05 10 PM: a prompt with interleaved text and images, colored by dependency type]

Here, different colors represent different types of dependencies:

  • Red: Can be computed independently of each other
  • Yellow: Depends on Image Embedding 0
  • Green: Depends on Image Embedding 1

In V0, we didn't consider these dependencies. V0 circumvented them by always processing the entire prompt (all images & text) at once. However, this is not desirable, since it doesn't fit well with other optimizations such as chunked prefills and prefix caching.

Proposal

To address this limitation, this PR proposes to make the V1 scheduler consider & track these dependencies explicitly, and perform flexible & fine-grained scheduling based on them. One example looks like the following:
[Screenshot 2024-11-07 at 10 06 17 PM: an example schedule where encoder and decoder inputs are interleaved across steps]

  1. The scheduler leverages chunked prefills for the decoder inputs, so that TPOT stays under control (a simplified scheduling sketch follows this list).
  2. Furthermore, the scheduler ensures that not too many images are processed by the vision encoder in the same step, because this can cause a spike in TTFT/TPOT.
  3. This fine-grained scheduling will also allow using prefix caching for VLMs, although that is not implemented in this PR.
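
Below is a rough, self-contained sketch of how a single scheduling step could combine these two constraints. The function name, arguments, and numbers are purely illustrative and do not correspond to the actual V1 scheduler API.

def plan_step(num_computed_tokens: int,
              num_prompt_tokens: int,
              image_placeholders: list,   # (offset, length) pairs, in order
              decoder_budget: int,
              encoder_budget: int):
    """Return (decoder tokens to run, image indices to encode) for one step."""
    num_new_tokens = min(num_prompt_tokens - num_computed_tokens, decoder_budget)
    images_to_encode = []
    for i, (offset, length) in enumerate(image_placeholders):
        if offset + length <= num_computed_tokens:
            continue  # Embedding already fully consumed in an earlier step.
        if offset >= num_computed_tokens + num_new_tokens:
            break     # Placeholder starts beyond this chunk; handle it later.
        if length > encoder_budget:
            # Not enough encoder budget: truncate the chunk to end right
            # before the placeholder and run a decoder-only step.
            num_new_tokens = offset - num_computed_tokens
            break
        images_to_encode.append(i)
        encoder_budget -= length
    return num_new_tokens, images_to_encode

# Example: 16 text tokens, a 576-token image placeholder, then 8 text tokens.
print(plan_step(0, 600, [(16, 576)], decoder_budget=256, encoder_budget=2048))
# -> (256, [0]): the image is encoded now, and its embedding is consumed by
#    the decoder over several chunked-prefill steps.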

Implementation

  • The scheduler has an “encoder budget” (e.g., the number of input image tokens for the ViT) and a “decoder budget” (the number of input tokens).
  • The scheduler explicitly schedules the encoder and decoder inputs, considering the input dependencies.
    • The vision encoder and the LLM decoder live on the same GPU.
    • In every step, the model runner first (optionally) runs the vision encoder, and then runs the LLM decoder, possibly with the output of the encoder (see the sketch after this list).
  • The model runner caches the encoder outputs (e.g., image embeddings) in an encoder cache on the GPU until the entire tensor is consumed by the decoder.
    • We should limit the maximum size of this cache, since the encoder outputs can be large. This works as a scheduling constraint in the scheduler.
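
To make the step ordering concrete, here is a toy, self-contained sketch. The method names mirror those that appear in the diffs reviewed below (_execute_encoder, _gather_encoder_outputs); the data structures and return values are made up and far simpler than the real GPUModelRunner.

class ToyModelRunner:
    """Toy illustration of the per-step ordering; not the real GPUModelRunner."""

    def __init__(self):
        self.encoder_cache = {}  # (req_id, input_id) -> image embedding

    def _execute_encoder(self, scheduler_output):
        # Run the vision encoder on the scheduled encoder inputs and keep the
        # outputs in the on-GPU encoder cache.
        for req_id, input_id in scheduler_output["scheduled_encoder_inputs"]:
            self.encoder_cache[(req_id, input_id)] = f"emb[{req_id}:{input_id}]"

    def _gather_encoder_outputs(self, scheduler_output):
        # Collect the cached embeddings that overlap with the decoder tokens
        # scheduled for this step.
        return [self.encoder_cache[key]
                for key in scheduler_output["needed_encoder_outputs"]]

    def execute_model(self, scheduler_output):
        self._execute_encoder(scheduler_output)
        encoder_outputs = self._gather_encoder_outputs(scheduler_output)
        # The real runner embeds the text tokens, scatters the image embeddings
        # into the placeholder positions, and runs the decoder on inputs_embeds.
        return {"decoder_ran_with": encoder_outputs}


runner = ToyModelRunner()
print(runner.execute_model({
    "scheduled_encoder_inputs": [("req0", 0)],
    "needed_encoder_outputs": [("req0", 0)],
}))  # -> {'decoder_ran_with': ['emb[req0:0]']}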

Limitations

  • Currently, the design only considers Llava-style model architectures (e.g., Pixtral, Molmo). It does not cover other architectures such as multi-modal Llama.
  • Currently, the implementation in this PR only supports Llava v1.5 and Phi3v, because of the necessary changes to each model's input processor. Support for other models will be added in a follow-up PR.
  • Currently, the encoder cache is just a pool of tensors. For more precise memory management, we need to store it in paged memory, just like the paged KV cache. I leave this as future work.
  • Currently, the scheduling logic for encoder inputs is a bit hacky because of some limitations in the V1 model runner. This needs to be further refined in the next PR.

Misc

To reduce conflicts, I reverted the changes to the detokenizer. Also, the MM input mapper will run in the same process as the engine (scheduler) for now; we will move it to a separate process later.

github-actions
github-actions189 days ago

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

ywang96 ywang96 assigned ywang96 ywang96 189 days ago
alexm-redhat
alexm-redhat commented on 2024-10-31
alexm-redhat189 days ago

The new code looks great. Also the performance should be better. Some nit comments

vllm/v1/core/scheduler.py
alexm-redhat189 days ago

What's the meaning of this num_new_tokens update here? For example, start_pos can be < num_computed_tokens, and then the result may be negative?

ywang96184 days ago (edited 184 days ago)

I don't think it's possible to have start_pos < num_computed_tokens here: This is because num_computed_tokens are tokens already processed, which means if there were an image with start_pos < num_computed_tokens, it should have been already processed in the previous iteration (either stored in KV cache, or cached in encoder cache).

If I understand correctly, the point of this update is that if we cannot run the encoder here, then we want to stop exactly before the first encoder position and run decoder-only processing for this iteration. However, I think it is possible to have start_pos == num_computed_tokens for a running request? (e.g., the first image token in a placeholder is exactly the first scheduled token, but the cache cannot allocate).

WoosukKwon184 days ago

It's possible when prefix caching is enabled (whilst we currently don't support prefix caching for VLMs).

WoosukKwon184 days ago

we want to stop exactly before the first encoder position and run decoder-only processing for this iteration.

Exactly.

vllm/v1/core/scheduler.py
if num_encoder_tokens > encoder_budget:
    # Cannot schedule because the encoder budget is exhausted.
    # NOTE(woosuk): We assume that the encoder tokens should be
    # processed altogether, as the model usually uses the
alexm-redhat189 days ago

comment cuts here...

WoosukKwon183 days ago

nice catch! Completed the sentence.

Just wanted to say that the encoder usually uses bidirectional attention, which requires every part of the input to be processed together.

vllm/v1/core/scheduler.py
alexm-redhat189 days ago👍 1

Seems like a code duplication with the running case. Maybe the duplication can be avoided somehow.

ywang96184 days ago

+1

WoosukKwon177 days ago

The code was simplified a bit. I found it difficult to further refactor, since it's only 5 lines of code, and it involves updating the local variables like scheduled_encoder_inputs and encoder_budget. The code looks ok to me. WDYT?

vllm/v1/request.py
@property
def num_encoder_inputs(self) -> int:
    # This method should be called only after the mm input mapper
    # has been applied.
alexm-redhat189 days ago

can assert here if mm_positions is not None (to avoid accidental call before mapper)

WoosukKwon177 days ago

The comment became irrelevant after #8346. Removed.

vllm/v1/worker/gpu_model_runner.py
self._update_states(scheduler_output)

# Run the encoder.
self._excute_encoder(scheduler_output)
alexm-redhat189 days ago

nit: => execute_encoder

WoosukKwon185 days ago

Nice catch! Fixed.

vllm/v1/worker/gpu_model_runner.py
alexm-redhat189 days ago

nit: A quick doc for this start/end indices computation would be helpful here.

WoosukKwon177 days ago

Added some comments above to help understand the logic.

vllm/v1/worker/gpu_model_runner.py
encoder_outputs = self._gather_encoder_outputs(scheduler_output)

# Prepare the decoder inputs.
inputs = self._prepare_inputs(scheduler_output)
alexm-redhat189 days ago👍 1

Good separation between encoder parts and decoder parts. All encoder pieces (preprocessors/vit) are done before the prepare_inputs call which is nice!

WoosukKwon184 days ago

Thanks!

mergify
mergify189 days ago

This pull request has merge conflicts that must be resolved before it can be
merged. @WoosukKwon please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify mergify added needs-rebase
mergify mergify removed needs-rebase
ywang96
ywang96 commented on 2024-11-05
ywang96184 days ago (edited 184 days ago)

Sorry for the very delayed review - left some comments!

FWIW - I ran a mini benchmark on this branch vs V0 on 1 x A100-80G.

Command: python vllm/examples/offline_inference_vision_language.py --num-prompts 1000

V0:

1000/1000 [01:33<00:00, 10.69it/s, est. speed input: 6369.87 toks/s, output: 682.44 toks/s]

V1 with this PR and default budget & cache:

1000/1000 [01:01<00:00, 16.13it/s, est. speed input: 9614.21 toks/s, output: 1029.49 toks/s]

V1 with encoder budget and cache size = 576 (this should be more or less equivalent to V1 with the previous VLM design):

1000/1000 [01:15<00:00, 13.18it/s, est. speed input: 7856.67 toks/s, output: 841.03 toks/s]
vllm/v1/core/scheduler.py
ywang96184 days ago

See my other comment on when num_new_tokens can be 0 for a running sequence.

comaniac183 days ago

I also fixed this (for the above decoder tokens) in the prefix caching PR. Also, to clarify the semantics of num_new_tokens:

  • Before calling _schedule_encoder_inputs, num_new_tokens would be the text tokens as well as image tokens (placeholder).
  • After calling _schedule_encoder_inputs, num_new_tokens may be the same as before if encoder budget allows; otherwise it would be reduced to only include text tokens.

Is this understanding correct?

WoosukKwon177 days ago

@comaniac Yes, correct. When the encoder cache or budget is insufficient, num_new_tokens can be reduced to end just before the encoder input (e.g., the image placeholder).
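
A tiny worked example of that semantic, with made-up numbers (not taken from the PR):

# A prompt has 8 text tokens, then a 576-token image placeholder, then more
# text; nothing has been computed yet.
num_computed_tokens = 0
num_new_tokens = 512                      # text + placeholder tokens, pre-check
placeholder_offset, placeholder_length = 8, 576

# If the encoder budget/cache cannot fit the 576-token image this step, the
# chunk is truncated to end right before the placeholder (decoder-only step):
num_new_tokens = min(num_new_tokens, placeholder_offset - num_computed_tokens)
assert num_new_tokens == 8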

vllm/v1/core/scheduler.py
scheduled_encoder_inputs[request.request_id] = (
    encoder_inputs_to_schedule)
# Update the encoder budget and allocate the encoder cache.
for i in encoder_inputs_to_schedule:
    encoder_budget -= request.get_num_encoder_tokens(i)
ywang96184 days ago

Can't we get this updated encoder budget directly from _schedule_encoder_inputs?

WoosukKwon177 days ago

Good suggestion! Fixed.

vllm/v1/core/scheduler.py
self.running_reqs_data: Dict[str, RunningRequestData] = {}

def schedule(self) -> "SchedulerOutput":
    scheduled_new_reqs: List[Request] = []
    scheduled_resumed_reqs: List[Request] = []
    scheduled_running_reqs: List[Request] = []
    preempted_reqs: List[Request] = []
    # Encoder-related.
    # NOTE(woosuk): Here, "encoder" includes the vision encoder. Currently,
    # we assume that the encoder also has the Transformer architecture
    # (e.g., ViT).
ywang96184 days ago
Suggested change
# NOTE(woosuk): Here, "encoder" includes the vision encoder. Currently,
# we assume that the encoder also has the Transformer architecture
# (e.g., ViT).
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if required). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT).
WoosukKwon183 days ago

Thanks for the suggestion! Fixed.

vllm/v1/core/scheduler.py
ywang96184 days ago (edited 184 days ago)

I wonder if it makes sense to evict mm_positions that have been fully processed from request.mm_positions at every step, because it seems that we will be unnecessarily iterating over them a lot (here and later in _gather_encoder_outputs).

This is not relevant at all for single-image models but I do think it'll matter for multi-image and video models.

WoosukKwon183 days ago

Good point. Unfortunately, as we discussed offline, all information needs to be kept because of preemption & re-computation.

I think we can optimize it in the future if this becomes a non-negligible overhead.

WoosukKwon
WoosukKwon184 days ago

@ywang96 Thanks for the review!

QQ: How did you measure the perf of V1 without this PR?

ywang96
ywang96184 days ago👍 1

@ywang96 Thanks for the review!

QQ: How did you measure the perf of V1 without this PR?

I have updated my original review comment - PTAL!

mergify
mergify183 days ago

This pull request has merge conflicts that must be resolved before it can be
merged. @WoosukKwon please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify mergify added needs-rebase
alexm-redhat
alexm-redhat183 days ago

FYI,

I did a quick performance benchmark for microsoft/Phi-3.5-vision-instruct with a separate process for the mm_mapper (on the old version of this PR) and without one. The results below show that the separate process adds a large TTFT overhead, even as the RPS goes up (which is a bit surprising) - I think it is most likely related to pickle/socket overheads. I did some manual timings specifically on the round-trip times to the separate process and saw that the mm_mapper is 5X slower with a separate process than running directly.

RPS | V0 - TTFT | V1 - TTFT (separate process mm_mapper)
1   | 67.05     | 127.99
5   | 73.23     | 143.1
10  | 84.28     | 190.66

RPS | V0 - TPOT | V1 - TPOT (separate process mm_mapper)
1   | 14.27     | 14.44
5   | 17.59     | 18.89
10  | 25.47     | 27.8

When there is no separate process, the performance looks much better:

RPS | V0 - TTFT | V1 - TTFT (direct mm_mapper)
1   | 67.05     | 69.91
5   | 73.23     | 78.47
10  | 84.28     | 89.23

RPS | V0 - TPOT | V1 - TPOT (direct mm_mapper)
1   | 14.27     | 13.10
5   | 17.59     | 14.19
10  | 25.47     | 16.17

The commands used are:

server: vllm serve microsoft/Phi-3.5-vision-instruct --trust-remote-code --max-model-len 4096 --enforce-eager --disable-async-output-proc

client: python benchmarks/benchmark_serving.py --backend openai-chat --base-url http://0.0.0.0:8000/v1 --endpoint /chat/completions --model microsoft/Phi-3.5-vision-instruct --dataset-path lmms-lab/LLaVA-OneVision-Data --dataset-name hf --hf-subset "chart2text(cauldron)" --hf-split train --num_prompts=100 --request-rate 5

comaniac
comaniac183 days ago

Thanks for the benchmarking. Could you also benchmark throughput? I suppose the benefit of separate processes should be more obvious in throughput instead of latency, as long as we pipeline mm_mapper well?

comaniac
comaniac commented on 2024-11-06
vllm/v1/core/encoder_cache_manager.py
from vllm.v1.request import Request


class EncoderCacheManager:
comaniac183 days ago👍 1

This looks pretty clean! We could consider supporting embedding sharing later.
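
For readers unfamiliar with the file, here is a minimal sketch of what a manager with this responsibility could look like. The field and method names are illustrative guesses, not the actual EncoderCacheManager API.

from typing import Dict, Set


class ToyEncoderCacheManager:
    def __init__(self, cache_size: int):
        self.num_free_slots = cache_size        # budget in encoder tokens
        self.cached: Dict[str, Set[int]] = {}   # req_id -> cached input ids

    def has_cache(self, req_id: str, input_id: int) -> bool:
        return input_id in self.cached.get(req_id, set())

    def can_allocate(self, num_tokens: int) -> bool:
        return num_tokens <= self.num_free_slots

    def allocate(self, req_id: str, input_id: int, num_tokens: int) -> None:
        self.cached.setdefault(req_id, set()).add(input_id)
        self.num_free_slots -= num_tokens

    def free(self, req_id: str, input_id: int, num_tokens: int) -> None:
        # Called once the decoder has consumed the whole embedding.
        self.cached[req_id].discard(input_id)
        self.num_free_slots += num_tokens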

vllm/v1/core/scheduler.py
comaniac183 days ago

In what situation would this limitation hurt performance?

vllm/v1/core/scheduler.py
self.running_reqs_data[request.request_id] = req_data
return req_data

def _schedule_encoder_inputs(
    self,
    request: Request,
    num_computed_tokens: int,
    num_new_tokens: int,
    encoder_budget: int,
) -> Tuple[List[int], int]:
comaniac183 days ago

Please docstring this function for readability.

WoosukKwon177 days ago

Added. Thanks for the suggestion.

ywang96
ywang96 commented on 2024-11-06
vllm/model_executor/models/llava.py
      """
      if intermediate_tensors is not None:
          inputs_embeds = None
-     else:
-         image_input = self._parse_and_validate_image_input(**kwargs)
-         if image_input is not None:
-             vision_embeddings = self._process_image_input(image_input)
-             inputs_embeds = self.language_model.model.get_input_embeddings(
-                 input_ids)
-
-             inputs_embeds = merge_multimodal_embeddings(
-                 input_ids, inputs_embeds, vision_embeddings,
-                 self.config.image_token_index)
-         else:
-             inputs_embeds = self.language_model.model.get_input_embeddings(
-                 input_ids)
-
-         # always pass the input via `inputs_embeds`
-         # to make sure the computation graph is consistent
-         # for `torch.compile` integration
-         input_ids = None
+     elif inputs_embeds is None:
+         vision_embeddings = self.process_mm_inputs(**kwargs)
+         # always pass the input via `inputs_embeds`
+         # to make sure the computation graph is consistent
+         inputs_embeds = self.get_inputs_embeds(input_ids,
ywang96183 days ago (edited 182 days ago)

If we're putting the encoder forward pass and embedding merge at model_runner level, then I don't think the code here is needed? (Is it possible for inputs_embeds to be None here when there's multimodal data in the request? If not, we just need to call embed_tokens here to get the text embeddings)

nvm - I see that it's needed here to be compatible with V0 - I will add a note accordingly in my PR to indicate that this needs to be cleaned up after we fully deprecate v0

mergify mergify removed needs-rebase
WoosukKwon WoosukKwon marked this pull request as ready for review 181 days ago
WoosukKwon WoosukKwon changed the title [V1] Support VLMs [V1] Support VLMs with fine-grained scheduling 181 days ago
ywang96
ywang96 commented on 2024-11-11
vllm/model_executor/models/phi3v.py
def get_inputs_embeds(
    self,
    input_ids: torch.Tensor,
    vision_embeddings: Optional[NestedTensors],
ywang96179 days ago
Suggested change
vision_embeddings: Optional[NestedTensors],
vision_embeddings: Optional[NestedTensors] = None,
WoosukKwon177 days ago

Good catch. Fixed.

mergify
mergify178 days ago

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @WoosukKwon.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify mergify added needs-rebase
mergify mergify removed needs-rebase
WoosukKwon WoosukKwon requested a review from ywang96 ywang96 177 days ago
WoosukKwon
WoosukKwon177 days ago

@ywang96 Addressed comments. PTAL.

WoosukKwon
WoosukKwon177 days ago

@ywang96 This PR actually requires adding a get_input_embeddings method to all models (while I only added it to llama, opt, llava, and phi3v in this PR), because it now executes the model's embedding layer and the rest of the model separately.

If we don't want to add this method to the text models, we can use self.model.model.get_input_embeddings instead, although it looks a bit hacky.
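
For context, the per-model change being discussed is roughly a thin wrapper over the embedding layer. The sketch below is illustrative only, not the exact code added in this PR.

import torch
import torch.nn as nn


class ToyTextModel(nn.Module):
    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Exposes only the embedding layer, so the V1 model runner can embed
        # text tokens separately, merge in image embeddings, and then run the
        # rest of the model on inputs_embeds.
        return self.embed_tokens(input_ids)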

ywang96
ywang96 approved these changes on 2024-11-12
ywang96177 days ago

@WoosukKwon Overall looks good to me! I left a few more comments mainly around code clarifications so please take a look.

vllm/model_executor/models/llava.py
    input_ids: torch.Tensor,
    vision_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
    inputs_embeds = self.language_model.model.get_input_embeddings(
ywang96177 days ago

If we're going to add get_input_embeddings to all models, then let's do

Suggested change
inputs_embeds = self.language_model.model.get_input_embeddings(
inputs_embeds = self.language_model.get_input_embeddings(
vllm/v1/core/scheduler.py
    break
if num_encoder_tokens > encoder_budget:
    # The encoder budget is exhausted. We can only schedule the
    # decoder tokens just before the encoder input.
ywang96177 days ago
Suggested change
# decoder tokens just before the encoder input.
# decoder tokens up until the encoder input.
vllm/v1/worker/gpu_model_runner.py
ywang96177 days ago

I would add a comment here about encoder_outputs: it should be either

  • a tensor of shape [num_images, feature_size, hidden_size] when feature_size is fixed across all images, or
  • a list of length num_images of tensors of shape (feature_size, hidden_size) when the feature size is dynamic depending on the input images.
WoosukKwon177 days ago

Great idea. Added the comment. Thanks for the clarification!
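
For reference, the two layouts described in the comment above look like this (sizes made up for illustration, not taken from the PR):

import torch

hidden_size = 32

# Fixed feature size across images: one batched tensor of shape
# [num_images, feature_size, hidden_size].
fixed = torch.randn(3, 576, hidden_size)

# Dynamic feature size per image: a list of length num_images with tensors of
# shape (feature_size, hidden_size).
dynamic = [torch.randn(576, hidden_size),
           torch.randn(1024, hidden_size),
           torch.randn(144, hidden_size)]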

vllm/v1/worker/gpu_model_runner.py
# num_computed_tokens + num_scheduled_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
if start_pos >= num_computed_tokens + num_scheduled_tokens:
    # The encoder input is not needed in this step.
ywang96177 days ago
Suggested change
# The encoder input is not needed in this step.
# The encoder output is not needed in this step.
vllm/v1/worker/gpu_model_runner.py
    # The encoder input is not needed in this step.
    break
if start_pos + num_encoder_tokens <= num_computed_tokens:
    # The encoder input is already computed and stored
ywang96177 days ago
Suggested change
# The encoder input is already computed and stored
# The encoder output is already processed and stored
vllm/v1/core/scheduler.py
start_pos = request.mm_positions[input_id]["offset"]
num_tokens = request.mm_positions[input_id]["length"]
if start_pos + num_tokens <= request.num_computed_tokens:
    # The encoder input is already computed and stored
ywang96177 days ago
Suggested change
# The encoder input is already computed and stored
# The encoder output is already processed and stored
vllm/v1/worker/gpu_model_runner.py
ywang96177 days ago

I think it's worth a note that the model executable always takes inputs_embeds as input.

WoosukKwon177 days ago

Good point. Added a comment.

vllm/v1/core/scheduler.py
ywang96177 days ago (edited 177 days ago)

This means we always initialize an encoder cache manager regardless of the model type (text-only or multimodal).

I don't think this is technically an issue, because for text-only models there's no multimodal data during profiling (and thus no memory usage for the encoder cache), but we should probably add a note here to say so; otherwise it might be misleading that we're reserving space for text-only models.

WoosukKwon177 days ago

Good point. Actually, we don't even preallocate the space for multi-modal models, because the profiling logic in V1 model runner is not sophisticated at all. I will just mark it as TODO for now.

WoosukKwon177 days ago👍 1

Added a comment in 361cadf. PTAL.

ywang96
ywang96177 days ago (edited 177 days ago)

@WoosukKwon Everything looks good to me now - can you merge with main after #10272 is merged for the test fix? After that we can merge this.

WoosukKwon DCO
04edd1c6
WoosukKwon WoosukKwon force pushed to 04edd1c6 177 days ago
WoosukKwon WoosukKwon added ready
WoosukKwon WoosukKwon enabled auto-merge (squash) 177 days ago
WoosukKwon Fix for CI
0da5df8a
WoosukKwon fix
07ef65c8
WoosukKwon WoosukKwon merged bbd3e869 into main 177 days ago
WoosukKwon WoosukKwon deleted the v1-vlm-sched branch 177 days ago
petersalas
petersalas commented on 2024-11-15
vllm/v1/engine/core.py
"""Add request to the scheduler."""

req = Request.from_engine_core_request(request)
# FIXME(woosuk): The input mapping (e.g., PIL images to tensors) may
# take 10-50 ms, which can cause a spike in the latency. We should
# consider moving this to a separate thread.
if req.mm_data:
    req.mm_inputs = self.mm_input_mapper.process_inputs(
        req.mm_data, req.mm_processor_kwargs)
petersalas174 days ago👍 1

One very nice property of V0 + #8348 is that the input mapper can be skipped entirely if the multimodal item is covered by the prefix cache (in our use case with Ultravox we can have many audio chunks in each inference). Not sure if that's practical to preserve in V1?
