Enable MiscTests.test_map_side_effects to work with nn module inlining #126355

Closed
laithsakka opened this issue May 15, 2024 · 7 comments
Labels
module: dynamo module: nn Related to torch.nn oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module


laithsakka commented May 15, 2024

TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 python test/dynamo/test_misc.py -k MiscTests.test_map_side_effects
error:

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/data/users/lsakka/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 2756, in wrapper
    method(*args, **kwargs)
  File "/data/users/lsakka/pytorch/pytorch/test/dynamo/test_misc.py", line 5384, in test_map_side_effects
    with self.assertRaisesRegex(
AssertionError: "Can't inplace modify module params/buffers" does not match "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)

from user code:
   File "/data/users/lsakka/pytorch/pytorch/test/dynamo/test_misc.py", line 5381, in forward
    return map(body, xs)
  File "/data/users/lsakka/pytorch/pytorch/test/dynamo/test_misc.py", line 5378, in body
    self.w += 1

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

cc @mruberry @jbschlosser @walterddr @mikaylagawarecki @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

laithsakka added the module: nn label on May 16, 2024
laithsakka changed the title from "Inlining nn modules for test ] dynamo/test_misc.py" to "Inlining nn modules for test dynamo/test_misc.py" on May 16, 2024
laithsakka commented

First, mutating the nn module variable from within the map operation is not allowed in either case; the error message differs because of the following.
When nn module inlining is not enabled, the following code in builtin.py

    elif isinstance(obj, variables.NNModuleVariable):
        if not tx.output.is_root_tracer():
            raise AttributeMutationError(
                "Can't inplace modify module params/buffers inside HigherOrderOp"
            )

is triggered. It is not triggered when inlining is enabled, because obj is then an UnspecializedNNModuleVariable.
Instead, this check is bypassed, and a different check is triggered later, when set_attribute is called
on UserDefinedObjectVariable:
tx.output.side_effects.store_attr(self, name, value)
-->

    # People do things like self.dim = dim inside autograd.Function.
    # These are benign.
    if isinstance(item, AutogradFunctionContextVariable):
        return True
    if not is_side_effect_safe(item.mutable_local):
        unimplemented(
            "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
        )
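To make the two code paths concrete, here is a minimal pure-Python toy model of the dispatch described above. It is not Dynamo code: the class names mirror Dynamo's, but `setattr_on`, `store_attr_side_effect`, `error_for`, and the `created_in_scope`/`current_scope` bookkeeping are hypothetical simplifications, and the scope comparison only roughly stands in for `is_side_effect_safe`. The point it shows is that the same mutation inside a HOP body reaches a different check depending on which variable class wraps the module.

```python
# Toy model (hypothetical, not Dynamo code) of why `self.w += 1` inside a
# HOP body reports different errors depending on how the module is wrapped.


class Unsupported(Exception):
    """Hypothetical stand-in for Dynamo's Unsupported exception."""


class AttributeMutationError(Unsupported):
    """Hypothetical stand-in; modelled here as a subclass of Unsupported."""


class NNModuleVariable:
    """Wraps the module when nn module inlining is disabled."""

    def __init__(self, created_in_scope):
        # Tracer depth at which the object was first seen (0 = root tracer).
        self.created_in_scope = created_in_scope


class UserDefinedObjectVariable:
    def __init__(self, created_in_scope):
        self.created_in_scope = created_in_scope


class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
    """Wraps the module when nn module inlining is enabled."""


def store_attr_side_effect(obj, current_scope):
    # Roughly models the SideEffects check: only objects created in the
    # current (innermost) tracer scope may be mutated.
    if obj.created_in_scope != current_scope:
        raise Unsupported(
            "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
        )
    return "ok"


def setattr_on(obj, current_scope):
    # Models the builtin.py check, which only fires for NNModuleVariable.
    is_root_tracer = current_scope == 0
    if isinstance(obj, NNModuleVariable):
        if not is_root_tracer:
            raise AttributeMutationError(
                "Can't inplace modify module params/buffers inside HigherOrderOp"
            )
        return "ok"
    # Any other variable, including UnspecializedNNModuleVariable, skips the
    # check above and reaches the generic side-effect check instead.
    return store_attr_side_effect(obj, current_scope)


def error_for(module_var, current_scope=1):
    """Error message for mutating the module at `current_scope` (1 = HOP body), or None."""
    try:
        setattr_on(module_var, current_scope)
    except Unsupported as exc:
        return str(exc)
    return None


print(error_for(NNModuleVariable(created_in_scope=0)))
# prints: Can't inplace modify module params/buffers inside HigherOrderOp
print(error_for(UnspecializedNNModuleVariable(created_in_scope=0)))
# prints: HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)
```

In this model both wrappers reject the mutation inside the HOP body and only the message differs, while the same mutation at the root scope (`current_scope=0`) is accepted for either wrapper.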


laithsakka commented May 16, 2024

We have two options here:

  1. Do not allow mutations, even for UnspecializedNNModuleVariable.
  2. Change the unit test to expect different messages depending on whether inlining is enabled,
     or just check that Unsupported is thrown.

@anijain2305 it sounds like (2) is the appropriate option; let me know otherwise.
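A sketch of option (2) using plain `unittest`, assuming hypothetical stand-ins: `Unsupported` and `compile_and_run` below only mimic what the real test gets from Dynamo (its Unsupported exception, and compiling `Module` with inlining on or off). It shows both variants: matching the configuration-specific message, and only checking that Unsupported is raised. One detail matters here: `assertRaisesRegex` treats the expected message as a regular expression, so the literal parentheses in "(SideEffects)" have to be escaped, for example with `re.escape`.

```python
import re
import unittest


class Unsupported(Exception):
    """Hypothetical stand-in for Dynamo's Unsupported exception."""


def compile_and_run(inline_nn_modules):
    # Hypothetical stand-in for compiling and calling the Module from the
    # test; it simply raises the message each configuration produces.
    if inline_nn_modules:
        raise Unsupported(
            "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
        )
    raise Unsupported("Can't inplace modify module params/buffers inside HigherOrderOp")


class MapSideEffectsTest(unittest.TestCase):
    def check_message(self, inline_nn_modules):
        if inline_nn_modules:
            expected = "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
        else:
            expected = "Can't inplace modify module params/buffers"
        # Escape the message so "(SideEffects)" is matched literally
        # instead of being parsed as a regex group.
        with self.assertRaisesRegex(Unsupported, re.escape(expected)):
            compile_and_run(inline_nn_modules)

    def test_message_without_inlining(self):
        self.check_message(inline_nn_modules=False)

    def test_message_with_inlining(self):
        self.check_message(inline_nn_modules=True)

    def test_only_unsupported(self):
        # The looser variant: ignore the message and only require Unsupported.
        for flag in (False, True):
            with self.assertRaises(Unsupported):
                compile_and_run(flag)


if __name__ == "__main__":
    unittest.main()
```

The message-specific variant documents which check fires in each mode; the looser variant is less brittle if the wording of either error changes.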


ezyang commented May 16, 2024

IMO it should be possible for Dynamo to trace through NN module mutations


laithsakka commented May 16, 2024

> IMO it should be possible for Dynamo to trace through NN module mutations

It does in general, but not here, where the mutation happens from within a map higher-order operation.
Example 1 (this works):

import torch


class Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.tensor(1)

    def forward(self, xs):
        def body(x):
            self.w += 1
            return x

        # body is called directly, so the mutation happens in the root tracer
        return body(xs)

Example 2 (this does not work):


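import torch
from functorch.experimental.control_flow import map  # the map higher-order operator


class Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.tensor(1)

    def forward(self, xs):
        def body(x):
            self.w += 1
            return x

        # body is traced inside the map HOP, so self.w is outside its scope
        return map(body, xs)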

@ezyang, do you think we should make example 2 above work as well? With inlining enabled, the module is treated as a UserDefinedObjectVariable.

laithsakka changed the title from "Inlining nn modules for test dynamo/test_misc.py" to "Enable MiscTests.test_map_side_effects to work with nn module inlining" on May 16, 2024
xmfan added the triaged and module: dynamo labels on May 16, 2024

ezyang commented May 17, 2024

If map is the HOP, then no, I don't expect it to work, but this should have failed without nn module inlining too :P

laithsakka commented

> If map is the hop, no I don't expect it to work, but this should have failed without nn module inlining too :P

Yep, it fails both ways; only the error message changed.
