AttributeError: module 'jax.random' has no attribute 'KeyArray' #2800
Getting a similar error while loading the session. 1.5 Detected
Same problem, did anyone solve it?
Same problem here, it's a problem with jax.
Also seeing the same: AttributeError: module 'jax.random' has no attribute 'KeyArray'
Same problem here.
Same problem here! :( Training the UNet...
Ditto.
jax.random.KeyArray was removed in JAX v0.4.24, and the diffusers version this notebook installs only works with JAX v0.4.23 or earlier. Add this under requirements: !pip install "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
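A minimal sketch for checking which side of that version boundary the current runtime is on. It assumes, per the comment above, that jax.random.KeyArray is gone from 0.4.24 onwards; the helper names (parse_release, jax_has_keyarray, installed_jax_version) are hypothetical and not part of JAX or the notebook. It reads the installed version with importlib.metadata, so it never imports JAX itself.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# Assumption (from the comment above): jax.random.KeyArray is absent from this release on.
KEYARRAY_REMOVED_IN = (0, 4, 24)


def parse_release(v: str) -> Tuple[int, int, int]:
    """Return the leading MAJOR.MINOR.PATCH of a version string, e.g. '0.4.23' -> (0, 4, 23).

    Suffixes such as '.dev20240101' or 'rc1' are ignored.
    """
    m = re.match(r"(\d+)\.(\d+)\.(\d+)", v)
    if m is None:
        raise ValueError(f"unrecognised version string: {v!r}")
    major, minor, patch = (int(g) for g in m.groups())
    return (major, minor, patch)


def jax_has_keyarray(jax_version: str) -> bool:
    """True if this JAX release should still provide jax.random.KeyArray."""
    return parse_release(jax_version) < KEYARRAY_REMOVED_IN


def installed_jax_version() -> Optional[str]:
    """Version of the installed 'jax' distribution, or None if it is not installed."""
    try:
        return version("jax")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_jax_version()
    if v is None:
        print("jax is not installed")
    elif jax_has_keyarray(v):
        print(f"jax {v}: jax.random.KeyArray should still be available")
    else:
        print(f"jax {v}: jax.random.KeyArray has been removed; pin jax==0.4.23 for this notebook")
```

Since the training script is started in a fresh process by accelerate launch, what matters is the version installed on disk, which is what importlib.metadata reports.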
Worked like a charm. Thank you!
Can someone please tell me where exactly I need to put this line: !pip install "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
I am very new to all this; I muddled my way through the Colab notebook, hit the error and came here. Thank you in advance!
You have to click "show code" in the dependencies cell and add the line there. That made it work for me :)
Legend, thank you! Will try it now :D Edit: works a treat!
Yeah, I just started having this problem too. DjNastyMagic's fix worked.
I applied the fix by DJNastyMagic. Training the UNet...
ERROR INSTALLING CUDA: Traceback (most recent call last):
I encountered the error while training the UNet.
Training the UNet...
Traceback (most recent call last):
  File "/content/diffusers/examples/dreambooth/train_dreambooth.py", line 18, in <module>
    from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 27, in <module>
    from .pipelines import OnnxRuntimeModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/__init__.py", line 18, in <module>
    from .dance_diffusion import DanceDiffusionPipeline
  File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/dance_diffusion/__init__.py", line 1, in <module>
    from .pipeline_dance_diffusion import DanceDiffusionPipeline
  File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py", line 21, in <module>
    from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
  File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/pipeline_utils.py", line 39, in <module>
    from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py", line 31, in <module>
    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/unet_2d_condition_flax.py", line 25, in <module>
    from .modeling_flax_utils import FlaxModelMixin
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 45, in <module>
    class FlaxModelMixin:
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 192, in FlaxModelMixin
    def init_weights(self, rng: jax.random.KeyArray) -> Dict:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 54, in __getattr__
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.random' has no attribute 'KeyArray'
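Why this fails at import time rather than partway through training: line 192 of modeling_flax_utils.py uses jax.random.KeyArray as a parameter annotation inside the FlaxModelMixin class body, and on the Python 3.10 runtime shown in the traceback, annotations are evaluated when the def statement runs (unless the module uses from __future__ import annotations). So simply importing diffusers triggers the lookup. The sketch below reproduces this with a stand-in module object; fake_random, define_eager and define_deferred are hypothetical names, not JAX or diffusers code. It also shows that a string annotation defers the lookup. As far as I know, the JAX changelog suggests jax.Array as the replacement annotation; pinning JAX is the workaround when you cannot edit the installed diffusers.

```python
import sys
import types

# Hypothetical stand-in for jax.random in JAX >= 0.4.24: a plain module object
# that has no KeyArray attribute, so looking it up raises AttributeError.
fake_random = types.ModuleType("fake_jax_random")


def define_eager():
    """Mimics FlaxModelMixin.init_weights: the annotation expression is evaluated
    when the def statement inside the class body runs (Python < 3.14)."""

    class Mixin:
        def init_weights(self, rng: fake_random.KeyArray) -> dict:
            return {}

    return Mixin


def define_deferred():
    """Same method with a string annotation: nothing is looked up at definition time."""

    class Mixin:
        def init_weights(self, rng: "fake_random.KeyArray") -> dict:
            return {}

    return Mixin


if __name__ == "__main__":
    try:
        define_eager()
    except AttributeError as err:
        print("eager annotation failed:", err)
    else:
        # Python 3.14+ evaluates annotations lazily (PEP 649), so no error is raised here.
        print("eager annotation did not fail on Python", sys.version.split()[0])

    print("deferred annotation defined fine:", define_deferred().__name__)
```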
Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/accelerate_cli.py", line 43, in main
    args.func(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 837, in launch_command
    simple_launcher(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 354, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/usr/bin/python3', '/content/diffusers/examples/dreambooth/train_dreambooth.py', '--image_captions_filename', '--train_only_unet', '--save_starting_step=500', '--save_n_steps=500', '--Session_dir=/content/gdrive/MyDrive/Fast-Dreambooth/Sessions/t', '--pretrained_model_name_or_path=/content/stable-diffusion-custom', '--instance_data_dir=/content/gdrive/MyDrive/Fast-Dreambooth/Sessions/t/instance_images', '--output_dir=/content/models/t', '--captions_dir=/content/gdrive/MyDrive/Fast-Dreambooth/Sessions/t/captions', '--instance_prompt=', '--seed=805484', '--resolution=512', '--mixed_precision=fp16', '--train_batch_size=1', '--gradient_accumulation_steps=1', '--use_8bit_adam', '--learning_rate=2e-06', '--lr_scheduler=linear', '--lr_warmup_steps=0', '--max_train_steps=1500']' returned non-zero exit status 1.
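The second traceback is only accelerate launch reporting that the child train_dreambooth.py process exited with status 1; the actual cause is the AttributeError in the first traceback. A minimal sketch of the same pattern using only the standard library, with a throwaway child program that just raises (CHILD_CODE and launch are hypothetical, and this is not accelerate's code): an uncaught exception makes a Python child exit with status 1, and subprocess.run(..., check=True) turns that into CalledProcessError.

```python
import subprocess
import sys

# Hypothetical child program standing in for train_dreambooth.py: it dies with the same error.
CHILD_CODE = "raise AttributeError(\"module 'jax.random' has no attribute 'KeyArray'\")"


def launch(code: str) -> subprocess.CompletedProcess:
    """Run `code` in a fresh interpreter; raise CalledProcessError on a non-zero exit."""
    return subprocess.run(
        [sys.executable, "-c", code],
        check=True,
        capture_output=True,
        text=True,
    )


if __name__ == "__main__":
    try:
        launch(CHILD_CODE)
    except subprocess.CalledProcessError as err:
        # The launcher only sees the exit status; the real error is in the child's stderr.
        print("exit status:", err.returncode)
        print("child stderr (last line):", err.stderr.strip().splitlines()[-1])
```

In the notebook the child's output is not captured, which appears to be why both tracebacks are printed, the child's first and the launcher's second.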