pip seems to not be so happy with JAX #501

Open
xiki-tempula opened this issue Jun 5, 2023 · 1 comment

Comments

@xiki-tempula

I tried to install pymbar via pip in a conda env, and importing it fails with what looks like a JAX problem. Installing pymbar with conda works fine, though. Steps to reproduce:

Create the env

conda create -n test ipython
conda activate test

Install pymbar

pip install pymbar

Test it

ipython
>>> from pymbar import mbar

Gives

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 1
----> 1 from pymbar import mbar

File ~/mambaforge/envs/test/lib/python3.11/site-packages/pymbar/__init__.py:32
     29 __email__ = "levi.naden@choderalab.org,jaime.rodriguez-guerra@choderalab.org,michael.shirts@colorado.edu,john.chodera@choderalab.org"
     31 from . import timeseries, testsystems, confidenceintervals
---> 32 from .mbar import MBAR
     33 from .other_estimators import bar, bar_overlap, bar_zero, exp, exp_gauss
     34 from .fes import FES

File ~/mambaforge/envs/test/lib/python3.11/site-packages/pymbar/mbar.py:44
     42 import numpy as np
     43 import numpy.linalg as linalg
---> 44 from pymbar import mbar_solvers
     45 from pymbar.utils import (
     46     kln_to_kn,
     47     kn_to_n,
   (...)
     51     check_w_normalized,
     52 )
     54 logger = logging.getLogger(__name__)

File ~/mambaforge/envs/test/lib/python3.11/site-packages/pymbar/mbar_solvers.py:16
     14 if force_no_jax:
     15     raise ImportError("Jax disabled by force_no_jax in mbar_solvers.py")
---> 16 from jax.config import config
     18 config.update("jax_enable_x64", True)
     20 from jax.numpy import exp, sum, newaxis, diag, dot, s_

File ~/mambaforge/envs/test/lib/python3.11/site-packages/jax/__init__.py:35
     30 del _cloud_tpu_init
     32 # Confusingly there are two things named "config": the module and the class.
     33 # We want the exported object to be the class, so we first import the module
     34 # to make sure a later import doesn't overwrite the class.
---> 35 from jax import config as _config_module
     36 del _config_module
     38 # Force early import, allowing use of `jax.core` after importing `jax`.

File ~/mambaforge/envs/test/lib/python3.11/site-packages/jax/config.py:17
      1 # Copyright 2018 The JAX Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     14 
     15 # TODO(phawkins): fix users of this alias and delete this file.
---> 17 from jax._src.config import config  # noqa: F401

File ~/mambaforge/envs/test/lib/python3.11/site-packages/jax/_src/config.py:24
     21 import threading
     22 from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional
---> 24 from jax._src import lib
     25 from jax._src.lib import jax_jit
     26 from jax._src.lib import transfer_guard_lib

File ~/mambaforge/envs/test/lib/python3.11/site-packages/jax/_src/lib/__init__.py:74
     70   return _jaxlib_version
     72 version_str = jaxlib.version.__version__
     73 version = check_jaxlib_version(
---> 74   jax_version=jax.version.__version__,
     75   jaxlib_version=jaxlib.version.__version__,
     76   minimum_jaxlib_version=jax.version._minimum_jaxlib_version)
     80 # Before importing any C compiled modules from jaxlib, first import the CPU
     81 # feature guard module to verify that jaxlib was compiled in a way that only
     82 # uses instructions that are present on this machine.
     83 import jaxlib.cpu_feature_guard as cpu_feature_guard

AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
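
Because `import jax` itself is what fails, the installed versions of `jax` and `jaxlib` (which `check_jaxlib_version` compares in the traceback) cannot be read from `jax.version` here. One way to collect them for a report like this is to read package metadata instead of importing the packages. Below is a minimal sketch using the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical, and the list of packages to check is only a suggestion.

```python
from __future__ import annotations

from importlib import metadata


def installed_versions(*packages: str) -> dict[str, str | None]:
    """Hypothetical helper: look up installed distribution versions
    from package metadata, without importing the packages themselves."""
    versions: dict[str, str | None] = {}
    for pkg in packages:
        try:
            versions[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            versions[pkg] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    # Print the versions relevant to this import failure.
    for name, version in installed_versions("pymbar", "jax", "jaxlib", "numpy").items():
        print(f"{name}: {version or 'not installed'}")
```

Running this inside the failing env (`conda activate test`) and pasting the output next to the traceback makes a jax/jaxlib mismatch easy to spot.
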
@mikemhenry
Contributor

This should help:
#503

I'm not sure when we'll make another release, but JAX will become an optional dependency soon™️.
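
The traceback hints at why a broken JAX install is not simply skipped today. `mbar_solvers.py` raises `ImportError` when `force_no_jax` is set, which suggests the JAX import sits inside a guard that expects `ImportError`. This install instead fails with `AttributeError` during `jax`'s own initialization, and a guard like that would not catch it. Below is a minimal sketch of an optional-JAX import that falls back to NumPy on either error. It is not the code from #503. `load_array_backend` is a hypothetical name, and the sketch assumes NumPy is always installed.

```python
import importlib
import warnings


def load_array_backend(force_no_jax: bool = False):
    """Hypothetical helper: return (array_module, backend_name).

    Tries JAX first (with 64-bit floats enabled) and falls back to NumPy.
    """
    if not force_no_jax:
        try:
            jax = importlib.import_module("jax")
            jax.config.update("jax_enable_x64", True)
            jnp = importlib.import_module("jax.numpy")
            return jnp, "jax"
        except (ImportError, AttributeError) as exc:
            # A broken JAX install can fail with AttributeError (as in the
            # traceback in this issue), not only ImportError, so catch both.
            warnings.warn(f"JAX unavailable, falling back to NumPy: {exc!r}")
    import numpy as np

    return np, "numpy"


backend, name = load_array_backend(force_no_jax=True)
print(name)
```

Catching `AttributeError` this broadly can hide real bugs, so a fallback like this should probably warn (as above) instead of falling back silently.
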
