Exception occurred when getting the state dictionary(state_dict). #1272

Open
lintao185 opened this issue Mar 17, 2024 · 33 comments

@lintao185


In TorchSharp, I defined a model that contains an nn.Sequential named cv4. However, when I retrieved the state_dict of the entire model, the entries for cv4 were missing, which is very strange. Other submodules also contain nn.Sequential and their entries are retrieved correctly; only this last layer is missing.

@lintao185
Author

I suspect this call is the reason it fails to recognize all the submodules:
this.add_module(nameof(model), this.model);

@lintao185
Author

When the elements of ModuleList are nn.Module<Tensor, Tensor> instead of nn.Sequential, the state_dict of the ModuleList elements can be captured correctly.

@lintao185
Author

When model.add_module is called, it traverses the module's sub-items and calls RegisterComponents() on each of them. Sequential does not do anything like that. As a temporary workaround, explicitly call RegisterComponents() at the end of each custom module's constructor.

@lintao185
Author

Comparing Sequential and ModuleList: ModuleList overrides RegisterComponents, whereas Sequential does not. Because of that override, ModuleList automatically invokes the RegisterComponents of its child items.
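The difference can be reproduced with a self-contained toy that does not depend on TorchSharp. All names below (ToyModule, register_module, register_buffer, StateDict, Leaf, ListLike, SequentialLike, Parent) are hypothetical stand-ins, and the assumption that registering a child also invokes the child's RegisterComponents is taken from this discussion, not from the TorchSharp sources. Only the container that overrides RegisterComponents gets its children's buffers into the state dictionary.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Toy stand-ins only; none of these types come from TorchSharp.
var keys = new Parent().StateDict().Keys.OrderBy(k => k).ToList();
Console.WriteLine(string.Join(", ", keys)); // cv2.0.stride, cv2.1.stride (nothing from cv1)

abstract class ToyModule
{
    private readonly Dictionary<string, ToyModule> _modules = new Dictionary<string, ToyModule>();
    private readonly Dictionary<string, double[]> _buffers = new Dictionary<string, double[]>();

    protected void register_buffer(string name, double[] value) => _buffers[name] = value;

    // Assumption (from the thread): registering a child also asks the child
    // to register its own components.
    protected void register_module(string name, ToyModule child)
    {
        child.RegisterComponents();
        _modules[name] = child;
    }

    // Default: register nothing.
    protected virtual void RegisterComponents() { }

    // Flattens buffers into dotted keys, recursing only into registered children.
    public Dictionary<string, double[]> StateDict(string prefix = "")
    {
        var result = new Dictionary<string, double[]>();
        foreach (var (name, buffer) in _buffers) result[prefix + name] = buffer;
        foreach (var (name, child) in _modules)
            foreach (var (key, buffer) in child.StateDict(prefix + name + "."))
                result[key] = buffer;
        return result;
    }
}

// A leaf that never calls RegisterComponents on its own.
sealed class Leaf : ToyModule
{
    private readonly double[] stride = { 1.0, 1.0 };
    private bool _registered;

    protected override void RegisterComponents()
    {
        if (_registered) return;
        register_buffer(nameof(stride), stride);
        _registered = true;
    }
}

// Mirrors ModuleList: overrides RegisterComponents and registers children by index.
sealed class ListLike : ToyModule
{
    private readonly List<ToyModule> _list;
    private bool _registered;
    public ListLike(params ToyModule[] items) => _list = items.ToList();

    protected override void RegisterComponents()
    {
        if (_registered) return;
        for (int i = 0; i < _list.Count; i++) register_module($"{i}", _list[i]);
        _registered = true;
    }
}

// Mirrors Sequential as described above: holds children but has no override,
// so nothing ever registers them or triggers their RegisterComponents.
sealed class SequentialLike : ToyModule
{
    private readonly List<ToyModule> _list;
    public SequentialLike(params ToyModule[] items) => _list = items.ToList();
    public int Count => _list.Count;
}

sealed class Parent : ToyModule
{
    private readonly ToyModule cv1 = new SequentialLike(new Leaf(), new Leaf());
    private readonly ToyModule cv2 = new ListLike(new Leaf(), new Leaf());

    public Parent() => RegisterComponents(); // the top-level module registers itself

    protected override void RegisterComponents()
    {
        register_module(nameof(cv1), cv1);
        register_module(nameof(cv2), cv2);
    }
}
```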

@NiklasGustafsson
Contributor

Is there a smallish repro case that I can debug?

@lintao185
Author

using System;
using System.Linq;
using TorchSharp;
using static TorchSharp.torch;

public class AModel : TorchSharp.torch.nn.Module
{
    public TorchSharp.Modules.Sequential cv1;
    public TorchSharp.Modules.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>> cv2;
    public AModel():base(nameof(AModel))
    {
        cv1 = nn.Sequential(Enumerable.Range(0, 3).Select(x => new BModel()));
        cv2 = nn.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>>(Enumerable.Range(0, 3).Select(x => new BModel()).ToArray());
        RegisterComponents();
    }
}

public class BModel : torch.nn.Module<torch.Tensor,torch.Tensor>
{
    public Tensor stride;
    public Tensor stride2;


    public BModel() : base(nameof(BModel))
    {
        stride = torch.ones(100, 100);
        stride2 = torch.ones(100, 100);
    }

    public override torch.Tensor forward(torch.Tensor input)
    {
        throw new NotImplementedException();
    }
}
// Usage, e.g. inside Main:
var a = new AModel();
var aStateDict = a.state_dict(); // no cv1.* entries

In this case, the state_dict entries for cv1 are missing.

using System;
using System.Linq;
using TorchSharp;
using static TorchSharp.torch;

public class AModel : TorchSharp.torch.nn.Module
{
    public TorchSharp.Modules.Sequential cv1;
    public TorchSharp.Modules.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>> cv2;
    public AModel() : base(nameof(AModel))
    {
        cv1 = nn.Sequential(Enumerable.Range(0, 3).Select(x => new BModel()));
        cv2 = nn.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>>(Enumerable.Range(0, 3).Select(x => new BModel()).ToArray());
        RegisterComponents();
    }
}

public class BModel : torch.nn.Module<torch.Tensor, torch.Tensor>
{
    public Tensor stride;
    public Tensor stride2;


    public BModel() : base(nameof(BModel))
    {
        stride = torch.ones(100, 100);
        stride2 = torch.ones(100, 100);
        RegisterComponents(); // the only change from the previous example: BModel registers its own buffers
    }

    public override torch.Tensor forward(torch.Tensor input)
    {
        throw new NotImplementedException();
    }
}
// Usage, e.g. inside Main:
var a = new AModel();
var aStateDict = a.state_dict(); // now includes the cv1.* entries

In this case, the state_dict entries for cv1 are present.

@yueyinqiu
Contributor

So I guess the problem is that Sequential.RegisterComponents won't call RegisterComponents on its submodules.

@lintao185
Author

Sequential and ModuleList implement RegisterComponents differently.

@yueyinqiu
Contributor

yueyinqiu commented Mar 19, 2024

I suppose the issue could be simply solved by adding a call to the submodule's RegisterComponents in Sequential.Add.

However, in my opinion, every module should always call RegisterComponents itself (or register its modules, parameters, and buffers in some other way), so a parent never needs to deal with its submodules' registration: they will do that on their own.

But that doesn't seem to be the case. Even ModuleList and Sequential don't do it, which confuses me. It makes the following impossible:

using static TorchSharp.torch.nn;

var l = ModuleList(Linear(1, 1));
Console.WriteLine(l.state_dict().Count); // 0

And since RegisterComponents is protected, do I have to wrap the list in an outer module? I believe something is poorly designed here.

@lintao185
Author

 protected override void RegisterComponents()
 {
     if (_registered) return;

     for (int i = 0; i < _list.Count; i++) {
         register_module($"{i}", _list[i]);
     }
     _registered = true;
 }

This is the implementation of ModuleList, and I think adding similar code to Sequential should fix it.

@yueyinqiu
Contributor

yueyinqiu commented Mar 19, 2024

> protected override void RegisterComponents()
> {
>     if (_registered) return;
>
>     for (int i = 0; i < _list.Count; i++) {
>         register_module($"{i}", _list[i]);
>     }
>     _registered = true;
> }
>
> This is the implementation of ModuleList, and I think adding similar code in Sequential should be able to fix it.

However, you will still find that Sequential(new BModel()).state_dict() doesn't work: Sequential never calls its own RegisterComponents, so your module's RegisterComponents is never invoked either. That's probably because Sequential allows modules to be appended dynamically, so they have to be registered as they are added rather than by a single RegisterComponents call.

So one solution might be to call the submodule's RegisterComponents in Sequential.Add. However, that might call RegisterComponents too early, especially when the submodules are themselves mutable. I'm not sure what the expected behavior should be.

And let me repeat my suggestion: perhaps all modules should always call RegisterComponents themselves, and parent modules should not care about registering the sub-submodules of their submodules (that is, they should not call RegisterComponents on their submodules).
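Under that protocol nothing has to reach into a submodule's protected members. A minimal sketch, again a hypothetical toy rather than TorchSharp's API: each module registers its own buffers as the last step of its own constructor, and a Sequential-like container registers each child under its index inside Add, so modules appended after construction are covered and the container also works standalone.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical toy, not TorchSharp: the container works on its own,
// including for modules appended after construction.
var seq = new ToySequential(new Leaf());
seq.Add(new Leaf());
Console.WriteLine(string.Join(", ", seq.StateDict().Keys.OrderBy(k => k))); // 0.stride, 1.stride

abstract class ToyModule
{
    private readonly Dictionary<string, ToyModule> _modules = new Dictionary<string, ToyModule>();
    private readonly Dictionary<string, double[]> _buffers = new Dictionary<string, double[]>();

    // Only records the child: under this protocol the child has already
    // registered its own components, so the parent never calls into it.
    protected void register_module(string name, ToyModule child) => _modules[name] = child;
    protected void register_buffer(string name, double[] value) => _buffers[name] = value;

    // Flattens buffers into dotted keys, recursing into registered children.
    public Dictionary<string, double[]> StateDict(string prefix = "")
    {
        var result = new Dictionary<string, double[]>();
        foreach (var (name, buffer) in _buffers) result[prefix + name] = buffer;
        foreach (var (name, child) in _modules)
            foreach (var (key, buffer) in child.StateDict(prefix + name + "."))
                result[key] = buffer;
        return result;
    }
}

sealed class Leaf : ToyModule
{
    private readonly double[] stride = { 1.0, 1.0 };

    // Registration is the last step of the module's own constructor.
    public Leaf() => register_buffer(nameof(stride), stride);
}

sealed class ToySequential : ToyModule
{
    private readonly List<ToyModule> _list = new List<ToyModule>();

    public ToySequential(params ToyModule[] items)
    {
        foreach (var item in items) Add(item);
    }

    // Registers each child under its index at the moment it is appended.
    public void Add(ToyModule module)
    {
        register_module($"{_list.Count}", module);
        _list.Add(module);
    }
}
```

The trade-off raised above still applies: registering in Add means a child must be fully constructed (and already self-registered) when it is appended.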

@lintao185
Author

Calling RegisterComponents once in the top-level model, so that all nested modules get registered automatically, would be the most convenient.

@yueyinqiu
Contributor

yueyinqiu commented Mar 19, 2024

Ah... Sequential can't even call its submodules' RegisterComponents, since it's protected. Now I have no idea how to implement this without breaking other things. (ModuleList uses register_module, but Sequential currently does not; it keeps a List<torch.nn.IModule<Tensor, Tensor>> instead.)

@lintao185
Author

[screenshot not preserved]
I wonder if this could serve as a relatively good solution.

@yueyinqiu
Contributor

yueyinqiu commented Mar 20, 2024

I think this could work without side effects, but... hmm... I can't say for sure:

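protected override void RegisterComponents()
{
    foreach (var module in this._modules) {
        // register_module invokes the child's RegisterComponents as a side effect...
        this.register_module("sub", (nn.Module)module);
        // ...then the submodule registry is cleared so the temporary "sub" entry does not remain.
        _internal_submodules.Clear();
    }
}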

@lintao185
Author

lintao185 commented Mar 20, 2024

To facilitate the use of pre-trained weights in TorchSharp, it is advisable to maintain consistency with PyTorch as much as possible.

@yueyinqiu
Contributor

yueyinqiu commented Mar 20, 2024

> To facilitate the use of pre-trained weights in PyTorch, it is advisable to maintain consistency with PyTorch as much as possible.

That is what I mean. A module should register its parameters itself. In PyTorch this is done by the module's __setattr__ and __getattr__. That isn't possible in C#, which is why RegisterComponents exists. If you want a module to behave like it does in PyTorch, it should always call RegisterComponents in its own constructor rather than rely on others to call it.

In other words, every module should be usable on its own, not only as part of another module. In PyTorch, __setattr__ and __getattr__ handle that automatically, but in C#, a module that doesn't call RegisterComponents won't work correctly.

Umm... Perhaps the best solution would be a source generator?

@lintao185
Author

I'm not sure if this is feasible or not, but the idea is to call RegisterComponents within the parameterless constructor of nn.Module. This way, when you create a custom model that inherits from nn.Module, it will automatically register itself.

@yueyinqiu
Contributor

> I'm not sure if this is feasible or not, but the idea is to call RegisterComponents within the parameterless constructor of nn.Module. This way, when you create a custom model that inherits from nn.Module, it will automatically register itself.

Unfortunately, that's impossible. The constructor of nn.Module runs before the body of your model's constructor, so the fields and properties your constructor assigns have not been set yet, and there is nothing to register.
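This ordering is easy to observe in plain C#. In the hypothetical sketch below (no TorchSharp involved), ToyModule stands in for nn.Module and inspects the most-derived type's public fields from inside its own constructor. The field assigned in the derived constructor body is still null at that point. One nuance: C# runs a derived class's field initializers before the base constructor, so only assignments made in the constructor body (the usual case when layers are built from constructor arguments) are invisible to the base class.

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

var model = new MyModel();
// Only the field assigned in the constructor body was still null when the base constructor ran.
Console.WriteLine(string.Join(", ", model.NullInBaseCtor)); // layer

// Hypothetical stand-in for nn.Module.
class ToyModule
{
    public readonly List<string> NullInBaseCtor = new List<string>();

    protected ToyModule()
    {
        // GetType() already returns the most-derived type here,
        // but that type's constructor body has not run yet.
        var flags = BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly;
        foreach (var field in GetType().GetFields(flags))
            if (field.GetValue(this) is null) NullInBaseCtor.Add(field.Name);
    }
}

class MyModel : ToyModule
{
    public object layer;                        // assigned in the constructor body
    public object initialized = new object();   // field initializer: runs before the base constructor

    public MyModel()
    {
        layer = new object();
    }
}
```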

@lintao185
Author

Switching gears, we could declare properties and then mark them with a custom attribute that has a “name” which will be used as the registration name. Subsequently, we could employ Fody to inject code into the getter and setter methods to handle the registration process. However, this approach is somewhat cumbersome.
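The naming half of this idea can be sketched with plain reflection; the Fody weaving into getters and setters is not shown. In this hypothetical toy, RegisterAsAttribute supplies the registration name, and a reflection-based RegisterComponents falls back to the field name when the attribute is absent. It still has to be called last in the constructor, which is exactly the step the weaving would automate.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

var dense = new Dense();
Console.WriteLine(string.Join(", ", dense.Registered.Keys.OrderBy(k => k))); // bias, weight

// Hypothetical attribute: the name under which a field is registered.
[AttributeUsage(AttributeTargets.Field)]
sealed class RegisterAsAttribute : Attribute
{
    public RegisterAsAttribute(string name) => Name = name;
    public string Name { get; }
}

abstract class ToyModule
{
    public Dictionary<string, double[]> Registered { get; } = new Dictionary<string, double[]>();

    // Walks the instance fields of the runtime type and registers every
    // non-null double[] field under its attribute name (or its field name).
    protected void RegisterComponents()
    {
        var flags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;
        foreach (var field in GetType().GetFields(flags))
        {
            if (field.GetValue(this) is not double[] value) continue;
            var name = field.GetCustomAttribute<RegisterAsAttribute>()?.Name ?? field.Name;
            Registered[name] = value;
        }
    }
}

sealed class Dense : ToyModule
{
    [RegisterAs("weight")] private double[] _weight;
    [RegisterAs("bias")] private double[] _bias;

    public Dense()
    {
        _weight = new double[] { 0.5, -0.5 };
        _bias = new double[] { 0.0 };
        RegisterComponents(); // must come last, after every field is assigned
    }
}
```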

@yueyinqiu
Contributor

yueyinqiu commented Mar 20, 2024

Yes, I suppose Fody or a source generator could be an elegant solution. It would also let us easily expose properties instead of fields, which is great too.

@NiklasGustafsson Could you please take a look at this?

@yueyinqiu
Contributor

yueyinqiu commented Mar 20, 2024

I have made a simplified demo here: https://github.com/yueyinqiu/TorchSharp.AutoRegister

PS: It's still impossible to get rid of the traditional constructors (and use a primary constructor instead), because we have to access the generated property. So sad :(

@NiklasGustafsson
Contributor

> > I'm not sure if this is feasible or not, but the idea is to call RegisterComponents within the parameterless constructor of nn.Module. This way, when you create a custom model that inherits from nn.Module, it will automatically register itself.
>
> Unfortunately, that's impossible. The constructor of nn.Module is run before the constructor of your model. Thus no values have been assigned to the properties and fields, so we can't register them...

That is exactly right. That's why RegisterComponents exists and needs to be called last in the (custom) module constructor.

@NiklasGustafsson
Contributor

> Switching gears, we could declare properties and then mark them with a custom attribute that has a “name” which will be used as the registration name. Subsequently, we could employ Fody to inject code into the getter and setter methods to handle the registration process. However, this approach is somewhat cumbersome.

That capability already exists. For example, in the rewrite we're working on for some of the standard modules, which will enable more attributes to be exposed, the parameters of Linear are defined as:

            const string WeightComponentName = nameof(weight);
            const string BiasComponentName = nameof(bias);

            public Parameter? bias {
                get => _bias;
                set {
                    _bias?.Dispose();
                    _bias = value?.DetachFromDisposeScope() as Parameter;
                    ConditionallyRegisterParameter(BiasComponentName, _bias);
                }
            }

            public Parameter weight {
                get => _weight!;
                set {
                    if (value is null) throw new ArgumentNullException(nameof(weight));
                    if (value.Handle != _weight?.Handle) {
                        _weight?.Dispose();
                        _weight = (value.DetachFromDisposeScope() as Parameter)!;
                        ConditionallyRegisterParameter(WeightComponentName, _weight);
                    }
                }
            }

            [ComponentName(Name = BiasComponentName)]
            private Parameter? _bias;
            [ComponentName(Name = WeightComponentName)]
            private Parameter? _weight;

@NiklasGustafsson
Contributor

> Perhaps all modules should always call RegisterComponents themselves, and parent modules should not care about registering the sub-submodules of their submodules (that is, they should not call RegisterComponents on their submodules).

Yes, this is exactly the intended protocol, and the documentation says so:
https://github.com/dotnet/TorchSharp/wiki/Creating-Your-Own-TorchSharp-Modules

The exception to this rule is modules such as the Linear module shown above, where the parameters may be altered after construction: the property setter needs to conditionally register the component, which allows you to assign it null as well as overwrite an already existing parameter.
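A toy version of such a setter, with hypothetical names (this is not TorchSharp's ConditionallyRegisterParameter, whose exact behavior is not reproduced here): in this sketch, assigning null removes the registration and assigning a new array replaces it, so the registered state always mirrors the property.

```csharp
#nullable enable
using System;
using System.Collections.Generic;

var lin = new ToyLinear();
Console.WriteLine(lin.Registered.ContainsKey("bias")); // True
lin.bias = null;                                       // clears the registration
Console.WriteLine(lin.Registered.ContainsKey("bias")); // False
lin.bias = new double[] { 1.0 };                       // registers the new value
Console.WriteLine(lin.Registered["bias"][0]);          // 1

abstract class ToyModule
{
    public Dictionary<string, double[]> Registered { get; } = new Dictionary<string, double[]>();

    // Hypothetical helper: register a non-null value, unregister on null.
    protected void ConditionallyRegister(string name, double[]? value)
    {
        if (value is null) Registered.Remove(name);
        else Registered[name] = value;
    }
}

sealed class ToyLinear : ToyModule
{
    private double[]? _bias;

    public ToyLinear() => bias = new double[] { 0.0 };

    public double[]? bias
    {
        get => _bias;
        set
        {
            _bias = value;
            // Keeps the registration in sync with the property on every assignment.
            ConditionallyRegister(nameof(bias), _bias);
        }
    }
}
```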

@NiklasGustafsson
Contributor

NiklasGustafsson commented Mar 20, 2024

> I have made a simplified demo here: https://github.com/yueyinqiu/TorchSharp.AutoRegister
>
> PS: It's still impossible to get rid of the traditional constructors (and use a primary constructor instead), because we have to access the generated property. So sad :(

As much as I dislike relying on reflection, which the current scheme does (I dislike it because it prevents AOT), having to use source code generation adds complexity and another step that has to be automated. That would be a last resort, I think.

The current scheme works fairly well as long as you follow the instructions very closely and don't do advanced stuff like the Linear module above. I don't know why you would allow setting the parameters after the module has been constructed, but PyTorch does, so TorchSharp should, too.

@lintao185
Author

So, to address the issue of Sequential not registering its sub-modules, for the time being, should we also rewrite Sequential’s RegisterComponents, just like we did with ModuleList?

@yueyinqiu
Contributor

yueyinqiu commented Mar 21, 2024

> > Perhaps all modules should always call RegisterComponents themselves, and parent modules should not care about registering the sub-submodules of their submodules (that is, they should not call RegisterComponents on their submodules).
>
> Yes, this is exactly the intended protocol, and the documentation says so: https://github.com/dotnet/TorchSharp/wiki/Creating-Your-Own-TorchSharp-Modules
>
> The exception to this rule is modules such as the Linear module shown above, where the parameters may be altered after construction: the property setter needs to conditionally register the component, which allows you to assign it null as well as overwrite an already existing parameter.

So my understanding is that custom modules should not rely on others calling their RegisterComponents. But then why does register_module call it? I think this may be misleading.

@NiklasGustafsson
Contributor

> So, to address the issue of Sequential not registering its sub-modules, for the time being, should we also rewrite Sequential’s RegisterComponents, just like we did with ModuleList?

We can certainly reconsider the protocol for module registration for the future. However, if the guidelines for custom modules described in the Wiki article are followed, the current protocol works.

@NiklasGustafsson
Contributor

> [...] why we are doing that in register_module?

As unsatisfying as this answer is -- I don't recall.

@yueyinqiu
Contributor

> > [...] why we are doing that in register_module?
>
> As unsatisfying as this answer is -- I don't recall.

Haha... Perhaps we should check all the modules provided by TorchSharp and remove this call if nothing depends on it? I believe this should be done as early as possible, before more projects come to rely on it by mistake.

@yueyinqiu
Contributor

yueyinqiu commented Mar 22, 2024

> > > [...] why we are doing that in register_module?
> >
> > As unsatisfying as this answer is -- I don't recall.
>
> Haha... Perhaps we should check all the modules provided by TorchSharp and remove this call if nothing depends on it? I believe this should be done as early as possible, before more projects come to rely on it by mistake.

Oh well, there is at least one thing (ModuleList) that depends on it:

using static TorchSharp.torch.nn;

var l = ModuleList(Linear(1, 1));
Console.WriteLine(l.state_dict().Count); // 0

I suppose it should be modified to behave similarly to Sequential.

@lintao185
Author

> > So, to address the issue of Sequential not registering its sub-modules, for the time being, should we also rewrite Sequential’s RegisterComponents, just like we did with ModuleList?
>
> We can certainly reconsider the protocol for module registration for the future. However, if the guidelines for custom modules described in the Wiki article are followed, the current protocol works.

For now I'm calling RegisterComponents() in each custom module; it just feels a bit verbose, haha.

Labels: None yet
Projects: None yet
Development: No branches or pull requests

3 participants