
Whisper JAX is not faster than Whisper in colab GPU environment. #152

Open
bianxg opened this issue Oct 23, 2023 · 4 comments

Comments

@bianxg commented Oct 23, 2023

Whisper JAX is not faster than Whisper in the Colab T4 GPU environment. Why?
I tested with an 841-second audio file: Whisper JAX took 182 seconds, while Whisper took only 148 seconds (both using the small model).

Please see the Whisper JAX test code:
https://drive.google.com/file/d/1T9sGsOS4md5169jAnSpQX_tHGbS4yFEC/view?usp=sharing
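The linked notebook isn't reproduced here; for reference, a minimal sketch of this kind of side-by-side comparison (the audio path is a placeholder) might look like the following. Note that the whisper-jax timing below includes JIT compilation on the first call, which turns out to matter later in this thread.

```python
import time

import whisper                                # openai-whisper
from whisper_jax import FlaxWhisperPipline    # whisper-jax

AUDIO = "audio.mp3"  # placeholder for the 841-second test file

# openai-whisper, small checkpoint
model = whisper.load_model("small")
start = time.perf_counter()
model.transcribe(AUDIO)
print(f"openai-whisper: {time.perf_counter() - start:.1f} s")

# whisper-jax, small checkpoint
pipeline = FlaxWhisperPipline("openai/whisper-small")
start = time.perf_counter()
pipeline(AUDIO)
print(f"whisper-jax:    {time.perf_counter() - start:.1f} s")  # first call -> includes compilation
```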

@r2d209git

I have the same question.

@WasamiKirua commented Nov 27, 2023

This happens not only on Colab but also on consumer hardware. I can run Whisper medium on my 8 GB VRAM GPU with no issue, but with Whisper JAX I somehow have to run it in dtype float16 to avoid an OOM error. Is there a logical explanation for this?
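For what it's worth, here is a minimal sketch of reducing memory use, based on the options documented in the whisper-jax README (the checkpoint, batch size, and audio path are placeholders): half precision roughly halves the parameter memory, and a smaller batch size shrinks the activations the batched pipeline holds at once. JAX also pre-allocates most of the GPU memory by default, which can make OOMs show up earlier than under PyTorch.

```python
import os
# Optional: let JAX allocate GPU memory on demand instead of pre-allocating
# most of it up front (must be set before JAX initialises its backend).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline(
    "openai/whisper-medium",
    dtype=jnp.float16,   # half-precision weights/compute
    batch_size=4,        # fewer 30-second chunks per batch -> smaller activations
)

outputs = pipeline("audio.mp3")  # returns a dict containing the transcribed text
```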

@sanchit-gandhi (Owner)

Hey @bianxg - it looks like you're measuring the compilation time, which is supposed to be slow. Any subsequent calls to the pipeline will be much faster since we leverage the compiled function. You can see this in action in this Kaggle notebook: https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu
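Concretely, the benchmarking pattern this implies is roughly the following sketch (the audio path is a placeholder): time the second call, not the first.

```python
import time
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-small")

start = time.perf_counter()
pipeline("audio.mp3")  # 1st call: JIT compilation + transcription
print(f"1st call (incl. compile): {time.perf_counter() - start:.1f} s")

start = time.perf_counter()
outputs = pipeline("audio.mp3")  # 2nd call: reuses the cached compiled function
print(f"2nd call (compiled):      {time.perf_counter() - start:.1f} s")
```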

@RezaTokhshid

@bianxg @r2d209git @WasamiKirua
I think other than the compile time, the problem is that we are all looking at the number from the demo. That number covers only the forward pass, which @sanchit-gandhi has since updated. I'm getting the same numbers as the demo on TPUs; however, where it gets really slow is the post-processing step, where [_decode_asr](https://github.com/huggingface/transformers/blob/28de2f4de3f7bf5dbe995237b039e95b446038c3/src/transformers/models/whisper/tokenization_whisper.py#L882) is called, and that's just super slow.

Has anyone had luck getting better results?

@sanchit-gandhi Any tips on how that was optimized in the demo? I'm at around 11 s, whereas it feels like 1-3 seconds on the Hugging Face demo you have.

Some benchmarks on my side:

On a T4 GPU:
transcription: 111.9 s
post-processing: 12.2 s
Wall time: 2min 11s

On a v3-8 TPU:
transcription: 3.6 s
post-processing: 12.6 s
Wall time: 22.1 s

As you can see: same audio, same post-processing time, huge transcription boost!

Help on how to lower that 12 seconds is much appreciated.
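One way to confirm where the time goes, without relying on whisper-jax's internal stage methods, is to wrap the Transformers tokenizer method linked above so its runtime is reported separately from the end-to-end call. This is only a rough sketch: it assumes the pipeline exposes its tokenizer as `pipeline.tokenizer`, which may differ between versions.

```python
import time
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-small")
pipeline("audio.mp3")  # warm-up call so JIT compilation is excluded below

# Wrap _decode_asr on this tokenizer instance to time the post-processing step alone.
original_decode_asr = pipeline.tokenizer._decode_asr

def timed_decode_asr(*args, **kwargs):
    start = time.perf_counter()
    result = original_decode_asr(*args, **kwargs)
    print(f"post-processing (_decode_asr): {time.perf_counter() - start:.1f} s")
    return result

pipeline.tokenizer._decode_asr = timed_decode_asr

start = time.perf_counter()
outputs = pipeline("audio.mp3", return_timestamps=True)
print(f"end-to-end:                    {time.perf_counter() - start:.1f} s")
```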
