
Trying to stack tensors from different devices in _pad_to_max_length in Whisper batched inference #30223

Closed
2 of 4 tasks
cifkao opened this issue Apr 12, 2024 · 2 comments · Fixed by #30787

@cifkao
Contributor

cifkao commented Apr 12, 2024

This issue appears to be caused by the following line, added in #29065 to fix #29036. That fix breaks batched inference on GPU/MPS because torch.tensor([]) creates the empty tensor on the CPU instead of on the model's device:

sequences.append(torch.tensor([]))
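A minimal sketch of a device-safe version of the empty-placeholder pattern. `pad_to_max_length_sketch` is a hypothetical stand-in written for illustration only; it is not the code in generation_whisper.py and not necessarily what #30787 does. It assumes every non-empty sequence already lives on `device` and only changes where the empty placeholder is created.

```python
import torch


def pad_to_max_length_sketch(sequences, pad_token_id, device):
    """Hypothetical helper: right-pad 1-D token tensors and stack them.

    It does not move inputs between devices, so every tensor passed in
    must already be on `device`; otherwise torch.cat/torch.stack raise.
    """
    max_len = max((seq.shape[-1] for seq in sequences), default=0)
    padded = []
    for seq in sequences:
        # Padding is created on the same device and dtype as the batch.
        pad = torch.full(
            (max_len - seq.shape[-1],), pad_token_id, dtype=seq.dtype, device=device
        )
        padded.append(torch.cat([seq, pad], dim=-1))
    return torch.stack(padded, dim=0)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pad_token_id = 0  # illustrative value, not Whisper's real pad token id

sequences = [torch.tensor([5, 6, 7], device=device)]
# A sample with no speech yields no segments. A bare torch.tensor([]) would
# be float32 on the default device (the CPU unless changed); building the
# placeholder on `device` with an integer dtype keeps the batch consistent.
sequences.append(torch.tensor([], dtype=torch.long, device=device))

batch = pad_to_max_length_sketch(sequences, pad_token_id, device)
print(batch.tolist())  # [[5, 6, 7], [0, 0, 0]]
```

When a non-empty reference tensor is at hand, `reference.new_tensor([])` is another option, since `Tensor.new_tensor` inherits the reference's dtype and device by default.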

System Info

  • transformers version: 4.40.0.dev0 bf9a7ab
  • Platform: macOS-14.2.1-arm64-arm-64bit
  • Python version: 3.10.13
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.28.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@ylacombe @sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from datasets import Audio, load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import torch
import numpy as np

model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny", torch_dtype=torch.float16
)
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model.to("mps")  # any non-CPU device (MPS here, or a CUDA GPU) hits the bug

ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
audio = ds[:8]["audio"]
audio = [x["array"] for x in audio]
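# Overwrite the first clip with low-amplitude noise so Whisper is likely to
# treat it as no speech and produce no segments for it.
audio[0][:] = np.random.normal(scale=0.05, size=audio[0].shape)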
inputs = processor(
    audio,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
    sampling_rate=16_000,
)
inputs = inputs.to(model.device, torch.float16)

result = model.generate(
    **inputs,
    no_speech_threshold=0.2,
    logprob_threshold=0.0,
    temperature=(0.0,),
    task="transcribe",
    language="fr",
)
decoded = processor.batch_decode(
    result, skip_special_tokens=False, decode_with_timestamps=True
)
print(decoded)

Traceback (most recent call last):
  File "/Users/ondra/bordel/test-hf/repr_batch_device_issue.py", line 27, in <module>
    result = model.generate(
  File "/Users/ondra/mambaforge/envs/transformers/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py", line 730, in generate
    sequences = _pad_to_max_length(final_segments, generation_config.pad_token_id, padding="right")
  File "/Users/ondra/mambaforge/envs/transformers/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py", line 153, in _pad_to_max_length
    sequences = torch.stack(sequences, dim=0)
RuntimeError: torch.cat(): all input tensors must be on the same device. Received cpu and mps:0
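The error comes from torch.stack, which concatenates its inputs and requires them all to be on one device. As a sketch, a hypothetical guard like `stack_on_one_device` (not part of transformers or PyTorch) can report the mismatch more directly; the "meta" device stands in for mps:0 so the failure can be reproduced without a GPU.

```python
import torch


def stack_on_one_device(tensors, dim=0):
    """Hypothetical guard: fail early with a readable message when the
    inputs span several devices, instead of torch.cat's RuntimeError."""
    devices = sorted({str(t.device) for t in tensors})
    if len(devices) > 1:
        raise ValueError(f"tensors are on different devices: {devices}")
    return torch.stack(tensors, dim=dim)


ok = stack_on_one_device([torch.zeros(2), torch.ones(2)])
print(ok.shape)  # torch.Size([2, 2])

message = None
try:
    # One CPU tensor and one "meta" tensor mimic the cpu/mps:0 mix above.
    stack_on_one_device([torch.zeros(2), torch.zeros(2, device="meta")])
except ValueError as err:
    message = str(err)
print(message)  # tensors are on different devices: ['cpu', 'meta']
```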

Expected behavior

No error: generate() should return transcriptions for the whole batch.

@amyeroberts
Collaborator

Gentle ping @sanchit-gandhi @ylacombe

@ylacombe
Collaborator

Thanks for the nice catch, @cifkao! I've opened #30787 to address this.
