ImportError: cannot import name 'MergeHeads' from 'trax.layers.attention' #1766

Open
jpontalba opened this issue Dec 6, 2022 · 0 comments

Description

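Importing trax raises an ImportError. The first trax import, `from trax import fastmath`, executes `trax/__init__.py`, which imports `trax.data` and then `trax.data.tf_inputs`. Further down that chain the import fails with `cannot import name 'MergeHeads' from 'trax.layers.attention'`.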
...

Environment information

trax 1.4.1

OS: Ubuntu 

$ pip freeze | grep trax
trax                         1.4.1

$ pip freeze | grep tensor
mesh-tensorflow==0.1.21
tensor2tensor==1.15.7
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-addons==0.18.0
tensorflow-datasets==4.7.0
tensorflow-estimator==2.11.0
tensorflow-gan==2.1.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.28.0
tensorflow-metadata==1.12.0
tensorflow-probability==0.7.0
tensorflow-text==2.11.0
tensorstore==0.1.28

$ pip freeze | grep jax
jax==0.3.25
jaxlib==0.3.25

$ python -V
Python 3.8.10
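The versions above come from several `pip freeze | grep` calls. The following is a minimal sketch that gathers the same information in one step from inside the failing notebook. It assumes Python 3.8+, where `importlib.metadata` is part of the standard library, and it also confirms that the notebook kernel sees the same environment as the shell. `collect_versions` is a hypothetical helper, not part of trax or pip. Unlike `grep`, it matches case-insensitively and only against distribution names.

```python
import platform
from importlib import metadata


def collect_versions(substrings=("trax", "tensor", "jax")):
    """Return {distribution name: version} for installed distributions whose
    name contains any of `substrings` (case-insensitive), sorted by name."""
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        # Skip broken installs that have no Name in their metadata.
        if name and any(s in name.lower() for s in substrings):
            found[name] = dist.version
    return dict(sorted(found.items()))


if __name__ == "__main__":
    print("Python", platform.python_version())
    for name, version in collect_versions().items():
        print(f"{name}=={version}")
```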

For bugs: reproduction and error logs

# Steps to reproduce:
!pip install -q -U trax

import numpy as np  # regular ol' numpy

from trax import fastmath
from trax import layers as tl
from trax import shapes
from trax.fastmath import numpy as jnp  # For use in defining new layer types.
from trax.shapes import ShapeDtype
from trax.shapes import signature
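If you only need the module and file that the failing name was requested from, and not the full traceback, you can wrap the first trax import from the steps above as shown below. This sketch relies only on the standard `ImportError.name` and `ImportError.path` attributes. `try_import` is a hypothetical helper.

```python
import importlib


def try_import(module_name):
    """Import `module_name` and return None on success. On ImportError,
    return a dict with the error's `name` (the module the failing name was
    requested from, or the missing module itself), `path` (that module's
    file, or None when unknown) and message, instead of raising."""
    try:
        importlib.import_module(module_name)
    except ImportError as err:
        return {"name": err.name, "path": err.path, "message": str(err)}
    return None


if __name__ == "__main__":
    # trax.fastmath is the first trax import in the steps above; importing it
    # also runs trax/__init__.py, which is where the traceback starts.
    print(try_import("trax.fastmath"))
```

For a "cannot import name X from Y" error, CPython sets `name` to Y. For this report, `name` should therefore be `'trax.layers.attention'`. When CPython can determine it, `path` is the `attention.py` file that was actually loaded, which shows whether that file comes from the `.venv` seen in the traceback.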
# Error logs:
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[22], line 3
      1 import numpy as np  # regular ol' numpy
----> 3 from trax import fastmath
      4 from trax import layers as tl
      5 from trax import shapes

File ~/NovaceneAI/trax_projects/.venv/lib/python3.8/site-packages/trax/__init__.py:18
      1 # coding=utf-8
      2 # Copyright 2021 The Trax Authors.
      3 #
   (...)
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     16 """Trax top level import."""
---> 18 from trax import data
     19 from trax import fastmath
     20 from trax import layers

File ~/NovaceneAI/trax_projects/.venv/lib/python3.8/site-packages/trax/data/__init__.py:70
     67 from trax.data.inputs import UnBatch
     68 from trax.data.inputs import UniformlySeek
---> 70 from trax.data.tf_inputs import add_eos_to_output_features
     71 from trax.data.tf_inputs import BertGlueEvalStream
...
     35 from trax.layers.attention import SplitIntoHeads
     38 # Layers are always CamelCase, but functions in general are snake_case
     39 # pylint: disable=invalid-name

ImportError: cannot import name 'MergeHeads' from 'trax.layers.attention'
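The traceback is truncated, so the file name of the frame that performs the failing import is not visible. One way to narrow this down without triggering the failing import is to parse the installed `trax/layers/attention.py` and list the names it defines at module level.

This sketch makes a few assumptions:

* It assumes the package layout implied by the traceback, with `layers/attention.py` under the site-packages `trax` directory.
* `top_level_names` and `installed_module_file` are hypothetical helpers.
* It only sees names bound directly by `def`, `class`, assignment, or import statements, so a name created dynamically would be reported as missing.

A missing `MergeHeads` could suggest that the installed trax files do not all come from the same version. That is an assumption to check, not a confirmed diagnosis.

```python
import ast
import importlib.util
import os


def top_level_names(source):
    """Return the names bound at module level by def, class, assignment or
    import statements in `source`. The source is parsed, never executed."""
    names = set()
    for node in ast.parse(source).body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            names.add(node.name)
        elif isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name):
                    names.add(target.id)
        elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
            names.add(node.target.id)
        elif isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                if alias.name != "*":  # star imports bind unknown names
                    # `import a.b` binds `a`; `import a.b as c` binds `c`.
                    names.add((alias.asname or alias.name).split(".")[0])
    return names


def installed_module_file(package, *relative_parts):
    """Build <package dir>/<relative_parts> without running the package's
    __init__.py: find_spec on a top-level name locates it but does not
    import it. Returns None if the package is missing or is not a package."""
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        return None
    package_dir = list(spec.submodule_search_locations)[0]
    return os.path.join(package_dir, *relative_parts)


if __name__ == "__main__":
    path = installed_module_file("trax", "layers", "attention.py")
    if path is None or not os.path.exists(path):
        print("trax/layers/attention.py not found on disk")
    else:
        with open(path, encoding="utf-8") as f:
            defined = top_level_names(f.read())
        for name in ("MergeHeads", "SplitIntoHeads"):
            print(name, "defined" if name in defined else "MISSING", "in", path)
```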