
einops does not appear to handle Tensorflow tensors with dynamic shape #187

Open

jesnie opened this issue May 13, 2022 · 3 comments

jesnie commented May 13, 2022

Describe the bug
There appear to be several bugs when trying to use einops with TensorFlow tensors that have a dynamic shape.

Reproduction steps

Steps to reproduce the behavior:

from einops import rearrange, parse_shape
import tensorflow as tf

def f(a: tf.Tensor) -> tf.Tensor:
    parsed_shape = parse_shape(a, "... n m")
    joined = rearrange(a, "... n m -> ... (n m)")
    # Hypothetically do something interesting with `joined` here.
    return rearrange(joined, "... (n m) -> ... n m", **parsed_shape)

static_shape = tf.Variable(tf.zeros((3, 4, 5)))
dynamic_shape = tf.Variable(tf.zeros((3, 4, 5)), shape=tf.TensorShape(None))

f(static_shape)  # <-- Works.
f(dynamic_shape)  # <-- Crashes.

Expected behavior
f should run without crashing and behave the same for both static_shape and dynamic_shape.

Your platform

  • Python: 3.10.4
  • tensorflow: 2.8.0
  • einops: 0.4.1

vdutor commented Jul 26, 2022

Hi @arogozhnikov, thanks for your great library. Any update on this issue?

Building on the example above, the following also breaks, where the tensor's shape is dynamic but its rank is known:

dynamic_shape_known_rank = tf.Variable(tf.zeros((3, 4, 5)), shape=tf.TensorShape([None, None, None]))

f(dynamic_shape_known_rank)  # <-- Crashes.

Relevant package versions:

einops: 0.4.1
tensorflow: 2.8.0
python: 3.9

arogozhnikov commented Oct 12, 2022

Hi @jesnie,
to make e.g. your second example work, einops needs at least symbols for the shape components; otherwise there is no way to specify how to reshape a tensor.

The only solution I see right now is to externally use tf.set_shape.
Previously, AFAIR, tf had a method (I think tf.get_shape) that returned symbols for the shape. It was later removed.
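
For instance, a minimal sketch of that workaround, assuming the concrete shape is known at the call site:

import tensorflow as tf
from einops import rearrange

def f_with_set_shape(a: tf.Tensor) -> tf.Tensor:
    # hypothetical name; pin a fully static shape so einops sees concrete dims
    a.set_shape([3, 4, 5])
    return rearrange(a, "... n m -> ... (n m)")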

I've sent a proposal for a generic API that would make such operations possible, and it could certainly use support now:
data-apis/array-api#494

jesnie commented Nov 6, 2022

I'm not sure what you mean by "symbols for shape components". I understand if this is a lot of work, and outside the scope of einops, but it is definitely possible to implement something like this. You "just" need to do all the arithmetic and shape manipulation in TensorFlow, and completely avoid pure Python types.
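
For illustration, the distinction that matters here is between a tensor's static shape metadata and tf.shape, which yields the shape as an ordinary int32 tensor:

a = tf.zeros((3, 4, 5))
print(a.shape)      # TensorShape([3, 4, 5]) -- static metadata; may contain None
print(tf.shape(a))  # tf.Tensor([3 4 5], shape=(3,), dtype=int32) -- defined at runtime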

Here's a small, dirty example:

import re
from typing import Any, Dict, List, Mapping, Tuple, Union

import tensorflow as tf

EllipsisType = Any
DimSpec = Union[str, Tuple[str, ...], EllipsisType]  # axis name, group of names, or "..."
TensorSpec = Tuple[DimSpec, ...]
RearrangeSpec = Tuple[TensorSpec, TensorSpec]


# Matches "...", a bare axis name, or a parenthesized group of axis names.
DIM_SPEC_RE = re.compile(r"(\.\.\.)|(\w+)|\((.*?)\)")


def parse_tensor_spec(spec: str) -> TensorSpec:
    result: List[DimSpec] = []
    for match in DIM_SPEC_RE.finditer(spec):
        i = match.lastindex
        if i == 1:
            result.append(...)
        elif i == 2:
            var_name = match.group(2)
            assert var_name is not None
            result.append(var_name)
        else:
            assert i == 3
            var_names = match.group(3)
            assert var_names is not None
            result.append(tuple(var_names.split()))
    return tuple(result)


def parse_rearrange_spec(spec: str) -> RearrangeSpec:
    frm, to = spec.split("->")
    return parse_tensor_spec(frm), parse_tensor_spec(to)


def flatten_tensor_spec(spec: TensorSpec) -> Tuple[Union[str, EllipsisType], ...]:
    result: List[int] = []
    for dim_spec in spec:
        if isinstance(dim_spec, tuple):
            result.extend(dim_spec)
        else:
            result.append(dim_spec)
    return tuple(result)


def parse_shape(t: tf.Tensor, spec: str) -> Mapping[str, tf.Tensor]:
    frm_spec = parse_tensor_spec(spec)
    frm_shape = tf.shape(t)
    # Keep the axis index as a tensor so it can depend on the runtime rank (for "...").
    i = tf.zeros((), dtype=tf.int32)
    sizes = {}
    for var_name in frm_spec:
        if isinstance(var_name, str):
            sizes[var_name] = frm_shape[i]
            i += 1
        elif isinstance(var_name, tuple):
            i += 1
        else:
            assert var_name is ...
            # "..." consumes however many axes are left over after the named specs.
            i += tf.size(frm_shape) - len(frm_spec) + 1
    return sizes


def rearrange(t: tf.Tensor, spec: str, **sizes: Union[int, tf.Tensor]) -> tf.Tensor:
    frm_spec, to_spec = parse_rearrange_spec(spec)
    # Normalize user-provided sizes to rank-1 tensors so they concatenate cleanly below.
    tf_sizes: Dict[Union[str, EllipsisType], tf.Tensor] = {
        dim_name: tf.reshape(tf.convert_to_tensor(size), [1])
        for dim_name, size in sizes.items()
    }

    i = tf.zeros((), dtype=tf.int32)
    frm_shape = tf.shape(t)
    for var_name in frm_spec:
        if isinstance(var_name, str):
            tf_sizes[var_name] = frm_shape[i : i + 1]
            i += 1
        elif isinstance(var_name, tuple):
            i += 1
        else:
            assert var_name is ...
            size = tf.size(frm_shape) - len(frm_spec) + 1
            tf_sizes[var_name] = frm_shape[i : i + size]
            i += size

    frm_spec_flat = flatten_tensor_spec(frm_spec)
    frm_shape_flat = tf.concat([tf_sizes[n] for n in frm_spec_flat], axis=0)
    t = tf.reshape(t, frm_shape_flat)

    i = tf.zeros((), dtype=tf.int32)
    frm_spec_indices = {}
    for var_name in frm_spec_flat:
        size = tf.size(tf_sizes[var_name])
        frm_spec_indices[var_name] = tf.range(i, i + size)
        i += size
    to_spec_flat = flatten_tensor_spec(to_spec)
    perm = tf.concat([frm_spec_indices[n] for n in to_spec_flat], axis=0)
    t = tf.transpose(t, perm)

    to_shape = []
    for dim_spec in to_spec:
        if isinstance(dim_spec, tuple):
            if not dim_spec:
                # An empty group "()" is a size-1 dim; keep it rank-1 so tf.concat below works.
                to_shape.append(tf.ones((1,), dtype=tf.int32))
            else:
                to_shape.append(
                    tf.math.reduce_prod([tf_sizes[n] for n in dim_spec], axis=0)
                )
        else:
            to_shape.append(tf_sizes[dim_spec])
    t = tf.reshape(t, tf.concat(to_shape, axis=0))

    return t


def f(a: tf.Tensor) -> tf.Tensor:
    parsed_shape = parse_shape(a, "... n m")
    joined = rearrange(a, "... n m -> ... (n m)")
    # Hypothetically do something interesting with `joined` here.
    return rearrange(joined, "... (n m) -> ... n m", **parsed_shape)


static_shape = tf.Variable(tf.zeros((2, 3, 4, 5)))
dynamic_shape = tf.Variable(tf.zeros((2, 3, 4, 5)), shape=tf.TensorShape(None))

f(static_shape)
f(dynamic_shape)

compiled_f = tf.function(f)

compiled_f(static_shape)
compiled_f(dynamic_shape)
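
Presumably the known-rank dynamic case from the comment above also goes through, since everything here uses tf.shape rather than static shape metadata:

dynamic_known_rank = tf.Variable(
    tf.zeros((2, 3, 4, 5)), shape=tf.TensorShape([None, None, None, None])
)
compiled_f(dynamic_known_rank)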
