Why doesn't whisper-jax use my GPU? #175

Open
bk111 opened this issue Jan 3, 2024 · 3 comments

Comments


bk111 commented Jan 3, 2024

from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# instantiate pipeline in float16 (half precision)

pipeline = FlaxWhisperPipline("openai/whisper-small", dtype=jnp.float16, batch_size=16)

text = pipeline("10m.mp3")
print(text)
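Before digging into whisper-jax itself, it can help to confirm which backend JAX actually selected, because whisper-jax simply runs on whatever device JAX picks. This is a minimal sketch: `jax.default_backend()` and `jax.devices()` are standard JAX functions, while the helper name `report_jax_backend` is hypothetical and not part of whisper-jax.

```python
def report_jax_backend() -> str:
    """Describe the backend JAX selected (hypothetical diagnostic helper)."""
    try:
        import jax  # imported lazily so the check still reports something if JAX is broken
    except (ImportError, RuntimeError) as exc:
        # RuntimeError covers e.g. a jax/jaxlib version mismatch detected at import time
        return f"jax could not be imported: {exc}"
    backend = jax.default_backend()  # "cpu", "gpu" or "tpu"
    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={backend}; devices=[{devices}]"


if __name__ == "__main__":
    # If this prints backend=cpu, the whisper-jax pipeline will also run on the CPU.
    print(report_jax_backend())
```

If the output says `backend=cpu` on a machine with an NVIDIA GPU, the problem is the JAX installation rather than the pipeline code.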

@nftblackmagic

You might need to check your jax version.

I am using jax 0.4.19, and it works.

!pip install jax==0.4.19
!pip install -U "jax[cuda12_pip]==0.4.19" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
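Because a CPU-only `jaxlib` imports without any error, it is worth checking which wheel the commands above actually installed. A minimal standard-library sketch; it assumes (not stated in this thread) that the CUDA builds of `jaxlib` served from the `jax_cuda_releases` index carry a local version tag such as `+cuda12.cudnn89`. Both helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional


def is_cuda_jaxlib(jaxlib_version: str) -> bool:
    """True if a jaxlib version string has a CUDA local tag, e.g. '0.4.19+cuda12.cudnn89'.

    Assumption: this tagging scheme matches the 0.4.19-era *_pip wheels.
    """
    _, _, local = jaxlib_version.partition("+")  # local version segment after '+', if any
    return local.startswith("cuda")


def installed_versions() -> Dict[str, Optional[str]]:
    """Map package name -> installed version, or None if it is not installed."""
    result: Dict[str, Optional[str]] = {}
    for pkg in ("jax", "jaxlib"):
        try:
            result[pkg] = version(pkg)
        except PackageNotFoundError:
            result[pkg] = None
    return result


if __name__ == "__main__":
    info = installed_versions()
    print(info)
    if info["jaxlib"] and not is_cuda_jaxlib(info["jaxlib"]):
        print("jaxlib looks like a CPU-only build")
```

Newer JAX releases ship CUDA support as separate plugin packages, so this tag check is only meaningful for the older `jax[cuda*_pip]` wheels used in this thread.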


bk111 commented Jan 6, 2024

I tried running it on Windows 11 → WSL2 → Ubuntu → a Docker container based on an nvidia/cuda image (https://hub.docker.com/r/nvidia/cuda/tags?page=2&name=11.8), with no luck. On the Windows side, Settings > Installed apps lists NVIDIA driver 546.33 and NVIDIA CUDA 11.8, while nvidia-smi reports CUDA 12.3.

Do I need "jax[cuda12_pip]==0.4.19" or "jax[cuda11_pip]==0.4.19"? And which container image version should I use?
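On the cuda11_pip vs cuda12_pip question: with the `*_pip` extras, the CUDA libraries themselves come from pip wheels, so the main constraint is the newest CUDA version the driver supports. That is the "CUDA Version" field in the `nvidia-smi` header (here 12.3), not the toolkit installed on Windows or in the container. A minimal sketch of that decision rule; the helper name is hypothetical, and the 11.8 floor for `cuda11_pip` is an assumption, not something confirmed in this thread.

```python
import re
from typing import Optional


def pick_jax_cuda_extra(nvidia_smi_output: str) -> Optional[str]:
    """Suggest a jax extra from the 'CUDA Version: X.Y' field of nvidia-smi (hypothetical helper).

    That field is the newest CUDA version the *driver* supports. The 11.8
    minimum for cuda11_pip is an assumption for jax 0.4.19-era wheels.
    """
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_output)
    if match is None:
        return None  # no GPU visible, or output not recognised
    major, minor = int(match.group(1)), int(match.group(2))
    if major >= 12:
        return "cuda12_pip"
    if (major, minor) >= (11, 8):
        return "cuda11_pip"
    return None  # driver too old for either wheel set (assumption)


if __name__ == "__main__":
    # Abbreviated header line using the values reported in this thread.
    header = "Driver Version: 546.33   CUDA Version: 12.3"
    extra = pick_jax_cuda_extra(header)
    print(f'pip install -U "jax[{extra}]==0.4.19" '
          "-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
```

Under this rule, a driver reporting CUDA 12.3 points to `cuda12_pip`, which matches the command suggested earlier in the thread.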


bk111 commented Jan 7, 2024


Now the GPU is working, but it's slow: transcribing a 10-minute MP3 took more than 300 s.
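Part of that time may be one-off JIT compilation: the whisper-jax README notes that the first call to the pipeline compiles the model and later calls reuse the compiled function. A minimal sketch for timing the first and later calls separately, plus the speed-up over real time (10 minutes of audio in 300 s is 600 / 300 = 2x real time). Both helper names are hypothetical; `pipeline` refers to the object created in the snippet at the top of this issue.

```python
import time
from typing import Any, Callable, List


def time_calls(fn: Callable[[Any], Any], arg: Any, repeats: int = 2) -> List[float]:
    """Wall-clock seconds for each of `repeats` calls to fn(arg) (hypothetical helper).

    With whisper-jax, the first entry includes JIT compilation; later
    entries reflect steady-state transcription speed.
    """
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(arg)
        durations.append(time.perf_counter() - start)
    return durations


def realtime_factor(audio_seconds: float, wall_seconds: float) -> float:
    """How many times faster than real time the transcription ran."""
    return audio_seconds / wall_seconds


if __name__ == "__main__":
    print(realtime_factor(10 * 60, 300))  # 2.0 for the figures reported above
    # Usage with the pipeline from the issue (not run here):
    # first, second = time_calls(pipeline, "10m.mp3")
```

If the second timing is much lower than the first, most of the 300 s was compilation rather than transcription.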
