
[Llava] Phi text model produces ValueError: Attention mask should be of size (1, 1, 1, 230), but is torch.Size([1, 1, 1, 8]) when using past_key_values in generate #30809

Open
xenova opened this issue May 14, 2024 · 8 comments

Comments

@xenova (Contributor) commented May 14, 2024

System Info

  • transformers version: 4.38.2
  • Platform: Linux-6.1.58+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.23.0
  • Safetensors version: 0.4.3
  • Accelerate version: 0.30.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1+cu121 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.8.3 (cpu)
  • Jax version: 0.4.26
  • JaxLib version: 0.4.26
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@gante (generate) @susnato (phi implementation) @younesbelkada (llava implementation)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Following the multi-round conversation tutorial from here, I put together this minimal reproduction to show how switching Llava to a Phi text model (instead of, e.g., Llama) results in an error when reusing past_key_values.

Running:

from PIL import Image
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load model and processor

# THIS WORKS
# model_id = "Xenova/tiny-random-LlavaForConditionalGeneration"
# model = LlavaForConditionalGeneration.from_pretrained(model_id)

# THIS DOESN'T WORK
model_id = "Xenova/tiny-random-LlavaForConditionalGeneration_phi"
model = LlavaForConditionalGeneration.from_pretrained(model_id, attn_implementation="eager")

processor = AutoProcessor.from_pretrained(model_id)

# Define inputs
prompt = "<image>Hi"
url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/white-image.png?download=true"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=prompt, images=image,
                   return_tensors="pt", padding=True)

# Generate w/o past_key_values
output = model.generate(
  **inputs,
  max_new_tokens=3,
  return_dict_in_generate=True,
  do_sample=False,
)

decoded = processor.batch_decode(
    output["sequences"], skip_special_tokens=False)

# Prepare new inputs
new_inputs = processor(decoded, return_tensors="pt", padding=True)

# Generate w/ past_key_values
generate_ids = model.generate(
    **new_inputs,
    do_sample=False,
    past_key_values=output['past_key_values'],
    max_new_tokens=20,
)
print(f'{generate_ids=}')

decoded2 = processor.batch_decode(
    generate_ids, skip_special_tokens=False)
print(f'{decoded2=}')

results in this error:

Traceback (most recent call last):
  File "/content/transformers.js/../test.py", line 39, in <module>
    generate_ids = model.generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1544, in generate
    return self.greedy_search(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2404, in greedy_search
    outputs = self(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llava/modeling_llava.py", line 469, in forward
    outputs = self.language_model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi/modeling_phi.py", line 1046, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi/modeling_phi.py", line 925, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi/modeling_phi.py", line 666, in forward
    attn_outputs, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi/modeling_phi.py", line 375, in forward
    raise ValueError(
ValueError: Attention mask should be of size (1, 1, 1, 230), but is torch.Size([1, 1, 1, 8])

Expected behavior

If you try with a Llama-based model (e.g., here; see comments), it works correctly.

@xenova (Contributor, Author) commented May 14, 2024

One way to get it running is to specify the vision inputs:

- new_inputs = processor(decoded, return_tensors="pt", padding=True)
+ new_inputs = processor(decoded, images=image, return_tensors="pt", padding=True)

(but it's still odd that Llama works without the vision inputs while Phi doesn't) 👀

@xenova (Contributor, Author) commented May 14, 2024

Here's a full example of it working (but inefficiently):

import requests
from PIL import Image

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "xtuner/llava-phi-3-mini-hf"

prompt = "<|user|>\n<image>\nWhat are these?<|end|>\n<|assistant|>\n"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"

model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)

processor = AutoProcessor.from_pretrained(model_id)

raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)

# Generate w/o past_key_values
output = model.generate(
    **inputs,
    max_new_tokens=3, # Stop early so we can test out continuing it later
    do_sample=False,
    return_dict_in_generate=True,
)
print(processor.decode(output.sequences[0][inputs.input_ids.shape[-1]:], skip_special_tokens=False))
# outputs "These are two"

# Generate w/ past_key_values
continued_ids = model.generate(
    output.sequences,
    pixel_values=inputs.pixel_values,
    attention_mask=torch.ones_like(output.sequences),
    do_sample=False,
    past_key_values=output['past_key_values'],
    max_new_tokens=20,
)
print(processor.decode(continued_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=False))
# outputs "These are two cats sleeping on a pink couch.<|end|><|end|><|end|><|end|><|endoftext|>"

As I understand it, the vision encoder still runs and the image embeddings are still merged into the inputs, even though they are cropped out later once the past_key_values take over.
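
For anyone curious, here is a rough way to check that (a sketch, not an official API: it reuses model, inputs and output from the working example above and counts vision tower forward passes with a standard PyTorch hook):

import torch

calls = {"n": 0}

def count_vision_forward(module, args, out):
    # Standard forward hook signature: (module, inputs, output)
    calls["n"] += 1

handle = model.vision_tower.register_forward_hook(count_vision_forward)
_ = model.generate(
    output.sequences,                                  # same continuation call as above,
    pixel_values=inputs.pixel_values,                  # just shortened to one new token
    attention_mask=torch.ones_like(output.sequences),
    past_key_values=output["past_key_values"],
    do_sample=False,
    max_new_tokens=1,
)
handle.remove()
print(f"vision tower forward passes during continuation: {calls['n']}")
# If the reading above is right, this prints at least 1, even though the freshly
# computed image embeddings are redundant with what the cache already holds.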

@zucchini-nlp (Member) commented:

Hey!

Llama models were the first to get new features like StaticCache and a _update_causal_mask method that builds the mask based on the input tensor. That is why the code fails for Phi and not for Llama.

But the real error lies in the way we handle vision-language models. Right now we expand the input embeddings inside the modeling file by concatenating image embeddings with text embeddings. When we try to continue generation, the past key values hold tensors with a much larger sequence length than the input text we feed here, and therefore the shapes of the keys/values do not match the shape of the attention mask.

continued_ids = model.generate(
    output.sequences,
    do_sample=False,
    past_key_values=output['past_key_values'],
    max_new_tokens=20,
)
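
To make the mismatch concrete, here is a small sketch against the first reproduction script, assuming the legacy tuple cache format where past_key_values[layer][0] is the key tensor of shape (batch, num_heads, seq_len, head_dim):

# The cache already holds the merged image + text sequence, while the new
# attention mask only covers the short text-only re-encoding of the prompt.
cache_len = output["past_key_values"][0][0].shape[-2]  # image patches + text, a few hundred positions
mask_len = new_inputs["attention_mask"].shape[-1]       # text tokens only, e.g. 8 positions
print(f"cached key/value length: {cache_len}")
print(f"new attention mask length: {mask_len}")
# Phi's eager attention expects the mask to span the cached positions plus the
# new tokens, which is where the "(1, 1, 1, 230) vs (1, 1, 1, 8)" error comes from.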

> One way to get it running is to specify the vision inputs:

Yes, this way we trigger the expansion of the inputs by concatenating the image embeddings, so that the final shapes match the past cache.

The best solution is to move the input expansion with dummy values into the processors. I believe @amyeroberts is working on it.
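
For illustration only, here is a minimal sketch of what processor-side expansion could look like, reusing the objects from the first reproduction script; num_image_tokens and the string-level expansion are assumptions for this sketch, not necessarily the approach taken in the actual fix:

# Hypothetical sketch: expand the single <image> placeholder into as many
# placeholder tokens as the vision tower produces, so the tokenized text
# already has the final, merged sequence length.
cfg = model.config.vision_config
num_image_tokens = (cfg.image_size // cfg.patch_size) ** 2  # assumption: one token per patch, ignoring CLS handling
expanded_prompt = prompt.replace("<image>", "<image>" * num_image_tokens)
expanded_inputs = processor(text=expanded_prompt, images=image, return_tensors="pt")
# With input_ids already at the merged length, the attention mask and the
# past_key_values agree on the sequence length when generation is continued.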

@xenova (Contributor, Author) commented May 15, 2024

Thanks for the great explanation @zucchini-nlp! Indeed, the differences between Phi and Llama confused me a bit (my use case is supporting this in Transformers.js), so I've now taken this into account in my implementation.

This is still problematic when images are passed, though. I'll continue looking into it.

@gante (Member) commented May 16, 2024

@zucchini-nlp this means updating LLMs to the new cache format will remove this error, correct?

@zucchini-nlp (Member) commented:

@gante no, it will not, as the past key values still will not match the size of the input sequence. The code will fail when identifying the correct start and end for cache_positions.

@gante (Member) commented May 29, 2024

(@zucchini-nlp since this is VLM related, I'm going to leave you in charge of fixing this issue :) )

@zucchini-nlp (Member) commented:

Okay, the draft PR is ready; let's link it here for visibility.

(Will be fixed by #30962)
