Passing a float in repeat as a dimension size prevents correct usage afterwards #309

Open
Maykeye opened this issue Feb 20, 2024 · 2 comments
Labels
wontfix This will not be worked on

Maykeye commented Feb 20, 2024

Consider the following scenario:

  1. The working environment is a Jupyter notebook (that's where this bug affects code in the wild).

  2. I call repeat(image, 'h -> (h H)', H=H) with H=2.0 (a float, not an integer). It fails, complaining about types. OK, fair. (In real code I used height/factor, where both height and factor are integers; for this simplified scenario 2.0 will suffice.)

  3. I call it again with H=2 (an int). It fails. What? But it fails if and only if I did step (2). If I skip it, it works fine.

  4. I did some digging and saw that the einops functions use an LRU cache.

  5. So the working theory is that if you call repeat improperly, you taint the cache.

  6. I call repeat enough times with different arguments to flush the LRU cache.

  7. I call it again with H=2 (an int). Everything works fine.
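The cache-tainting theory in steps 4 to 7 can be illustrated without einops or torch. This is a minimal sketch under stated assumptions: cooked_lengths is a hypothetical stand-in for a cached recipe step (not einops code), and maxsize=4 is chosen only to make the eviction in step 6 cheap to show.

```python
import functools

# Hypothetical stand-in for a cached recipe step; NOT einops code.
# It simply echoes back the axis lengths it was computed from.
@functools.lru_cache(maxsize=4)
def cooked_lengths(axes_lengths):
    return dict(axes_lengths)

# Step 2: the bad call stores a result that contains a float.
bad = cooked_lengths((("H", 2.0),))

# Step 3: the "correct" call hits the same cache entry, because
# ("H", 2) == ("H", 2.0) and both tuples hash the same.
tainted = cooked_lengths((("H", 2),))
print(type(tainted["H"]))  # <class 'float'>

# Step 6: enough distinct calls evict the stale entry (LRU, maxsize=4).
for i in range(3, 10):
    cooked_lengths((("H", i),))

# Step 7: the entry is recomputed from the int argument.
fresh = cooked_lengths((("H", 2),))
print(type(fresh["H"]))  # <class 'int'>
```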

Here's the full code to reproduce it:

import torch
from einops import repeat
image = torch.zeros(1)

def run_repeat(H):
    return repeat(image, 'h -> (h H)', H=H)

try:
    fail_of_course = run_repeat(2.0)  # float length: fails, but leaves a cache entry behind
except Exception:
    print("mistakes were made as expected")

try:
    should_not_fail = run_repeat(2)  # int length: hits the entry cached by the float call
except Exception:
    print("we still failed")

for i in range(3, 2048):  # enough distinct calls to evict the stale entry
    try:
        cache_away = run_repeat(i)
    except Exception:
        print("fail with i=", i)
try:
    should_not_fail = run_repeat(2)
    print("but this time we are fine")
except Exception:
    print("we still failed #2")

The output is:

mistakes were made as expected
we still failed
but this time we are fine

If the exception in the first should_not_fail call is not suppressed, the full unexpected error is:

  File "/tmp/a.py", line 13, in <module>
    should_not_fail = run_repeat(2)
                      ^^^^^^^^^^^^^
  File "/tmp/a.py", line 6, in run_repeat
    return repeat(image, 'h -> (h H)', H=H)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/einops/einops.py", line 641, in repeat
    return reduce(tensor, pattern, reduction="repeat", **axes_lengths)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/einops/einops.py", line 523, in reduce
    return _apply_recipe(
           ^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/einops/einops.py", line 248, in _apply_recipe
    tensor = backend.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/einops/_backends.py", line 267, in add_axes
    return x.expand(repeats)
           ^^^^^^^^^^^^^^^^^
TypeError: expand(): argument 'size' failed to unpack the object at pos 2 with error "type must be tuple of ints,but got float"

If repeat is never called with a float, only

but this time we are fine

is displayed.

Einops 0.7.0
Torch 2.2.0+cu121
Python 3.11.7 (main, Jan 29 2024, 16:03:57) [GCC 13.2.1 20230801]

ETA: yep, in _apply_recipe, replacing the call to _reconstruct_from_shape with _reconstruct_from_shape_uncached gets rid of the error. So the bug happens because hash(2.0) == hash(2).

(Also, functools.lru_cache(1024, typed=True) doesn't help as a band-aid: from lru_cache's point of view, the function always receives an argument of the same type, axes_lengths: HashableAxesLengths.)
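The point about typed=True matches the documented behavior of functools.lru_cache: type specificity applies only to a function's immediate arguments, not to the contents of a tuple argument. A small sketch with two hypothetical cached functions (not einops code):

```python
import functools

@functools.lru_cache(maxsize=1024, typed=True)
def first_length(axes_lengths):
    # Hypothetical: returns the first stored length unchanged.
    return axes_lengths[0][1]

@functools.lru_cache(maxsize=1024, typed=True)
def scalar_length(length):
    # Hypothetical: returns its scalar argument unchanged.
    return length

# The top-level argument is a tuple in both calls, so typed=True sees the
# same type and reuses the float entry for the int call.
print(first_length((("H", 2.0),)), first_length((("H", 2),)))  # 2.0 2.0

# With a bare scalar argument, typed=True does separate int and float.
print(scalar_length(2.0), scalar_length(2))  # 2.0 2
```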

Maykeye added the bug (Something isn't working) label Feb 20, 2024
arogozhnikov (Owner) commented Feb 23, 2024

Hi @Maykeye
You're correct, the reason is the lru_cache.
More precisely, that's because hash(2.0) == hash(2) and 2 == 2.0, so map[2] and map[2.0] refer to the same entry.
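The equality and hashing behavior described above can be checked in plain Python; lru_cache stores its results in a dict, so it looks keys up the same way. The dictionary and its contents here are purely illustrative:

```python
# 2 and 2.0 compare equal and hash equally, so a dict treats them as one key.
print(2 == 2.0, hash(2) == hash(2.0))  # True True

lengths = {2.0: "stored with a float key"}
print(lengths[2])  # stored with a float key

# Assigning through an equal key updates the value but keeps the
# original key object, so the float "wins" and stays in the dict.
lengths[2] = "overwritten via an int key"
print(list(lengths))  # [2.0]
```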

Switching from a plain LRU cache to a typed one would solve this problem, but it incurs a slowdown, so I'll just accept the current behavior and treat this situation as a programming error (i.e. user should fix it).

As a recovery: call _prepare_transformation_recipe.cache_clear() or restart the kernel (or just overflow the cache, which also works).
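As a toy illustration of the cache_clear() recovery option, using a hypothetical cached function rather than einops itself (with einops, the call suggested above is _prepare_transformation_recipe.cache_clear()), clearing an lru_cache drops the stale float entry so the next call is recomputed:

```python
import functools

@functools.lru_cache(maxsize=1024)
def cooked_lengths(axes_lengths):
    # Hypothetical cached step that echoes back the axis lengths.
    return dict(axes_lengths)

cooked_lengths((("H", 2.0),))            # bad call leaves a float entry behind
print(cooked_lengths((("H", 2),))["H"])  # 2.0 (stale hit)

cooked_lengths.cache_clear()             # drop every cached entry
print(cooked_lengths((("H", 2),))["H"])  # 2 (recomputed from the int)
print(cooked_lengths.cache_info().currsize)  # 1
```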

arogozhnikov added the wontfix (This will not be worked on) label and removed the bug (Something isn't working) label Feb 23, 2024
arogozhnikov changed the title from "[BUG] Proper use of repeat fails if before it was used improperly" to "Passing a float in repeat as a dimension size prevents correct usage afterwards" Feb 23, 2024
Maykeye (Author) commented Mar 1, 2024

Regarding "(i.e. user should fix it)":

The fix that is appropriate for the user is changing repeat(image, 'h -> (h H)', H=foo/bar) (incorrect call) to repeat(image, 'h -> (h H)', H=foo//bar) (correct call) in the Jupyter cell; rerunning the cell should then work without the float error showing up again.
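For the height/factor case from the original report, the difference is that / always produces a float in Python 3, even when the division is exact, while // on two ints produces an int. The values below are illustrative, and the divisibility check is an added assumption, not something proposed in the thread:

```python
height, factor = 64, 32  # illustrative values

print(height / factor, type(height / factor).__name__)    # 2.0 float
print(height // factor, type(height // factor).__name__)  # 2 int

# Hypothetical guard: // silently floors, so check divisibility first
# if a non-exact ratio would indicate a bug in the caller.
if height % factor != 0:
    raise ValueError(f"{height} is not divisible by {factor}")
H = height // factor
```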

_prepare_transformation_recipe.cache_clear is not documented.

The fact that it is named _prepare_transformation_recipe rather than prepare_transformation_recipe even suggests that users should not know it exists at all.

Maybe just add an assert in _reconstruct_from_shape_uncached where it iterates over the dimensions, e.g. assert not isinstance(dim, float), "dim can't be float!". This way the cache is never filled with floats in the first place, and since the result is cached anyway, one extra assert per cache miss will not slow the world down (besides, asserts can even be disabled with python -O).
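A minimal sketch of the proposed check, using hypothetical names that only mirror the einops structure (an uncached function wrapped by lru_cache). It relies on lru_cache not storing anything when the wrapped function raises. As an assumption that differs from the proposal, it raises TypeError instead of using assert, so the check still runs under python -O:

```python
import functools

def reconstruct_uncached(axes_lengths):
    # Hypothetical analogue of _reconstruct_from_shape_uncached.
    for name, dim in axes_lengths:
        if isinstance(dim, float):
            raise TypeError(f"length of axis {name!r} can't be float: {dim!r}")
    return dict(axes_lengths)

# Cached wrapper, analogous to a cached _reconstruct_from_shape.
reconstruct = functools.lru_cache(maxsize=1024)(reconstruct_uncached)

try:
    reconstruct((("H", 2.0),))
except TypeError as err:
    print("rejected:", err)

# The failed call left nothing behind, so the int call computes fresh.
print(reconstruct.cache_info().currsize)    # 0
print(type(reconstruct((("H", 2),))["H"]))  # <class 'int'>
```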
