[Llava] Phi text model produces ValueError: Attention mask should be of size (1, 1, 1, 230), but is torch.Size([1, 1, 1, 8]) when using past_key_values in generate #30809
Comments
One way to get it running is to specify the vision inputs:

- new_inputs = processor(decoded, return_tensors="pt", padding=True)
+ new_inputs = processor(decoded, images=image, return_tensors="pt", padding=True)

(but it's still odd why llama works and phi doesn't without the vision inputs) 👀
Here's a full example of it working (but inefficiently):

import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
model_id = "xtuner/llava-phi-3-mini-hf"
prompt = "<|user|>\n<image>\nWhat are these?<|end|>\n<|assistant|>\n"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
model = LlavaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(0)
processor = AutoProcessor.from_pretrained(model_id)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
# Generate w/o past_key_values
output = model.generate(
**inputs,
max_new_tokens=3, # Stop early so we can test out continuing it later
do_sample=False,
return_dict_in_generate=True,
)
print(processor.decode(output.sequences[0][inputs.input_ids.shape[-1]:], skip_special_tokens=False))
# outputs "These are two"
# Generate w/ past_key_values
continued_ids = model.generate(
output.sequences,
pixel_values=inputs.pixel_values,
attention_mask=torch.ones_like(output.sequences),
do_sample=False,
past_key_values=output['past_key_values'],
max_new_tokens=20,
)
print(processor.decode(continued_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=False))
# outputs "These are two cats sleeping on a pink couch.<|end|><|end|><|end|><|end|><|endoftext|>" As I understand it, the vision encoder is still run and the inputs are still merged, even though they will be cropped out later and the past_key_values will be used. |
Hey! Llama models were the first ones to get new features like StaticCache. But the true error lies in the way we handle vision-language models. Right now we expand the input embeddings inside the modeling file by concatenating image embeddings with text embeddings. When we try to continue generation from the returned sequences, the new inputs contain only text tokens:

continued_ids = model.generate(
    output.sequences,
    do_sample=False,
    past_key_values=output['past_key_values'],
    max_new_tokens=20,
)

so their length no longer matches the cached sequence, which was built on the expanded embeddings. Yes, this way (passing the vision inputs again) we can trigger the expansion of the inputs by concatenating image embeddings, so that the final shapes match the past cache. And yes, the best solution is to move the input expansion with dummy values into the processors. I believe @amyeroberts is working on it.
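To make that expansion step concrete, here is a schematic toy sketch of the merge described above (not the actual transformers implementation; all sizes are made up):

# Toy illustration of the merge: the single <image> placeholder in the text
# sequence is replaced by one embedding per image patch, so the sequence the
# language model caches is much longer than the tokenized text.
import torch

hidden = 32                                   # toy hidden size
num_patches = 576                             # hypothetical number of vision tokens
text_embeds = torch.randn(1, 8, hidden)       # 8 text tokens, one of them the <image> placeholder
image_embeds = torch.randn(1, num_patches, hidden)
image_token_pos = 3                           # hypothetical position of the placeholder

merged = torch.cat(
    [
        text_embeds[:, :image_token_pos],      # text before the placeholder
        image_embeds,                          # expanded image patches
        text_embeds[:, image_token_pos + 1:],  # text after the placeholder
    ],
    dim=1,
)
print(merged.shape)  # (1, 583, 32): the cache is built at this length, while a
                     # later call that skips expansion passes only ~8 ids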
Thanks for the great explanation @zucchini-nlp! Indeed, the differences between phi and llama confused me a bit (my use case is support for this in Transformers.js), so I've now taken this into account in my implementation. However, this is still problematic when images are passed. I'll continue looking into this.
@zucchini-nlp this means updating LLMs to the new cache format will remove this error, correct?
@gante no, it will not, as the past key values still will not match the size of the input sequence. The code will fail when identifying the correct start and end for cache_positions.
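Roughly, the failure mode looks like this (a simplified sketch of the position bookkeeping, not the actual transformers code; the numbers are taken from the error in the title):

# Simplified sketch: cache positions are derived from how many tokens the cache
# has already seen, which reflects the expanded (image-merged) sequence.
import torch

past_seen_tokens = 230                            # cache length after image expansion
input_ids = torch.ones(1, 8, dtype=torch.long)    # unexpanded continuation ids
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + input_ids.shape[1])
print(cache_position)  # positions 230..237, but the attention mask built from the
                       # unexpanded ids covers only 8 positions -> size mismatch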
(@zucchini-nlp since this is VLM related, I'm going to leave you in charge of fixing this issue :) )
Okay, the draft PR is ready; let's link it here for visibility (will be fixed by #30962).
System Info
transformers version: 4.38.2

Who can help?
@gante (generate) @susnato (phi implementation) @younesbelkada (llava implementation)

Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Following the multi-round conversation tutorial from here, I put together this minimal reproduction to show how switching Llava to use a Phi text model (instead of e.g., llama) results in an error when reusing past key values.
Running:
results in this error
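For reference, a minimal sketch of a reproduction along these lines (assuming the xtuner/llava-phi-3-mini-hf checkpoint and prompt used elsewhere in this thread; the second generate call, which omits pixel_values, is the one that raises the ValueError from the title):

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "xtuner/llava-phi-3-mini-hf"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
).to(0)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open(requests.get(
    "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
prompt = "<|user|>\n<image>\nWhat are these?<|end|>\n<|assistant|>\n"
inputs = processor(prompt, image, return_tensors="pt").to(0, torch.float16)

# First pass: build the cache
output = model.generate(
    **inputs, max_new_tokens=3, do_sample=False, return_dict_in_generate=True
)

# Second pass: reuse the cache without pixel_values. With the Phi text model this
# raises "Attention mask should be of size (1, 1, 1, 230), but is torch.Size([1, 1, 1, 8])"
model.generate(
    output.sequences,
    attention_mask=torch.ones_like(output.sequences),
    past_key_values=output["past_key_values"],
    do_sample=False,
    max_new_tokens=20,
)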
Expected behavior
If you try with a llama model (e.g., here; see comments), it works correctly.
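For comparison, a hedged sketch of the same continue-from-cache pattern with a llama-based Llava checkpoint (the checkpoint and prompt below are assumptions, not the exact example linked above):

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # assumed llama-based checkpoint
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
).to(0)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open(requests.get(
    "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
prompt = "USER: <image>\nWhat are these? ASSISTANT:"
inputs = processor(prompt, image, return_tensors="pt").to(0, torch.float16)

output = model.generate(
    **inputs, max_new_tokens=3, do_sample=False, return_dict_in_generate=True
)

# With the llama text model, continuing from the cache (without pixel_values)
# completes instead of raising the attention-mask size error seen with Phi.
continued_ids = model.generate(
    output.sequences,
    attention_mask=torch.ones_like(output.sequences),
    past_key_values=output["past_key_values"],
    do_sample=False,
    max_new_tokens=20,
)
print(processor.decode(continued_ids[0], skip_special_tokens=True))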