transformers
Enable fx tracing for Mistral #30209
Merged

zucchini-nlp commented 1 year ago

What does this PR do?

Fixes #30083. As per the title, this enables fx tracing for Mistral. Mistral was apparently already traceable with "sdpa" attention, since it closely follows Llama, which is already supported. I also enabled tracing for "eager" attention, which previously failed because Mistral uses a sliding window here.

Tests passing:

tests/models/mistral/test_modeling_mistral.py::MistralModelTest::test_torch_fx
tests/models/mistral/test_modeling_mistral.py::MistralModelTest::test_torch_fx_output_loss
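
For reference, a minimal sketch of the kind of usage this enables. The tiny config values below are illustrative only (not taken from the PR or its tests):

import torch
from transformers import MistralConfig, MistralForCausalLM
from transformers.utils.fx import symbolic_trace

# Tiny randomly initialised model; config values are illustrative only
config = MistralConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    attn_implementation="eager",  # the attention path this PR makes traceable
)
model = MistralForCausalLM(config)

# symbolic_trace builds a torch.fx.GraphModule with the requested inputs
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)
outputs = traced(input_ids=input_ids, attention_mask=attention_mask)  # runs like the original model
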
zucchini-nlp added commit c609326c: tracing for mistral
zucchini-nlp added commit 84fde606: typo
zucchini-nlp requested a review from michaelbenayoun 1 year ago
HuggingFaceDocBuilderDev commented 1 year ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

michaelbenayoun commented 1 year ago

Some other modeling files are copied from modeling_mistral.py, so you need to update them as well. It's easy to do: can you please run make fix-copies? It should basically enable fx symbolic tracing for the other models that should support it.

zucchini-nlp added commit 97923361: fix copies
zucchini-nlp commented 1 year ago

@michaelbenayoun Done! fix-copies also added tracing support for the MoE models, which was a bit unexpected. Anyway, I just removed a line with dynamic control flow from the MoE models and checked that it is not necessary (even when top_x is an empty tensor); see the sketch below.
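
For context, here is a simplified, self-contained sketch of the expert loop in question, modelled on the Mixtral-style sparse MoE block rather than copied verbatim from the library; the commented-out guard is the dynamic control flow that was removed:

import torch
import torch.nn as nn

# Toy sizes, purely illustrative
num_experts, top_k, hidden_dim, num_tokens = 4, 2, 8, 5

hidden_states = torch.randn(num_tokens, hidden_dim)
experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])

# Route each token to its top-k experts
router_logits = torch.randn(num_tokens, num_experts)
routing_weights, selected_experts = torch.topk(router_logits.softmax(dim=-1), top_k, dim=-1)
expert_mask = nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

final_hidden_states = torch.zeros_like(hidden_states)
for expert_idx in range(num_experts):
    idx, top_x = torch.where(expert_mask[expert_idx])

    # The removed line -- a data-dependent branch that torch.fx cannot trace:
    # if top_x.shape[0] == 0:
    #     continue

    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
    current_hidden_states = experts[expert_idx](current_state) * routing_weights[top_x, idx, None]
    # With an empty top_x this is a no-op, so the guard above is unnecessary
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))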

michaelbenayoun approved these changes on 2024-04-15 and commented

LGTM on my side.
Let's see what @ArthurZucker or @amyeroberts have to say about the top_x change.

michaelbenayoun requested a review from amyeroberts 1 year ago
michaelbenayoun requested a review from ArthurZucker 1 year ago
amyeroberts commented on 2024-04-16

Thanks for working on this!

If you trigger the case where top_x.shape[0] == 0, e.g. by setting

idx, top_x = torch.where(torch.zeros_like(expert_mask[expert_idx]))

in the lines above, does this still work in both the tracing and non-tracing cases?

zucchini-nlp commented 1 year ago

@amyeroberts Yes, it works fine for me when top_x is an empty tensor.
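
For the record, a minimal eager-mode check along these lines (CPU) backs this up: index_add_ with an empty index tensor is simply a no-op.

import torch

final_hidden_states = torch.zeros(4, 8)
top_x = torch.empty(0, dtype=torch.long)   # no tokens routed to this expert
current_hidden_states = torch.empty(0, 8)
final_hidden_states.index_add_(0, top_x, current_hidden_states)  # no-op, no error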

amyeroberts approved these changes on 2024-04-16 and commented

Thanks for adding this and confirming the top_x behaviour!

michaelbenayoun commented 1 year ago

So, just confirming: we can merge with the top_x line removed?

ArthurZucker commented on 2024-04-17

Cool! Yeah, let's remove it if that still produces correct behaviour and supports fx!

ArthurZucker approved these changes on 2024-04-17
zucchini-nlp commented 1 year ago

Merging now, since the removal of the top_x line is approved.

zucchini-nlp merged 304c6a1e into main 1 year ago
bozheng-hit commented 1 year ago

This PR introduces a bug for Qwen2MoE GPTQ models. Maybe revert it for the modeling_qwen2_moe.py file? @ArthurZucker

ArthurZucker commented 1 year ago

? Can you provide a reproducer and a stack trace that shows the error?

bozheng-hit commented 1 year ago

? Can you provide a reproducer and a stack trace that shows the error?

The code to reproduce the error is here:

from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4")

prompt = "Give me a short introduction to large language model."
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

generated_ids = model.generate(
    model_inputs.input_ids,
    max_new_tokens=512
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

The error information is below; the model generates successfully after I revert the change to modeling_qwen2_moe.py.

Traceback (most recent call last):
  File "/home/data/roy.zb/workspace/test_auto_gptq.py", line 23, in <module>
    generated_ids = model.generate(
  File "/cpfs01/shared/public/xingzhang.rxz/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/data/roy.zb/workspace/transformers/src/transformers/generation/utils.py", line 1656, in generate
    result = self._sample(
  File "/home/data/roy.zb/workspace/transformers/src/transformers/generation/utils.py", line 2819, in _sample
    outputs = self(
  File "/cpfs01/shared/public/xingzhang.rxz/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/data/roy.zb/workspace/transformers/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 1355, in forward
    outputs = self.model(
  File "/cpfs01/shared/public/xingzhang.rxz/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/data/roy.zb/workspace/transformers/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 1224, in forward
    layer_outputs = decoder_layer(
  File "/cpfs01/shared/public/xingzhang.rxz/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/data/roy.zb/workspace/transformers/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 934, in forward
    hidden_states = self.mlp(hidden_states)
  File "/cpfs01/shared/public/xingzhang.rxz/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/data/roy.zb/workspace/transformers/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 856, in forward
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

ArthurZucker commented 1 year ago

This seems to involve GPTQ quantisation. Can you open a separate issue and ping @younesbelkada and @SunMarc?
