Exception occurred when getting the state dictionary(state_dict).

Open HCareLou opened this issue 1 year ago • 34 comments

In TorchSharp, I defined a model that contains an nn.Sequential named cv4. However, when I obtained the state_dict of the entire model, the entries for cv4 were missing, which is very strange. Other submodels also contain nn.Sequential, and their entries are all returned correctly; only the last layer's are not.

HCareLou avatar Mar 17 '24 10:03 HCareLou

I suspect this call is the reason it fails to recognize all the modules: this.add_module(nameof(model), this.model);

HCareLou avatar Mar 17 '24 15:03 HCareLou

When the elements of ModuleList are nn.Module<Tensor, Tensor> instead of nn.Sequential, the state_dict of the ModuleList elements can be captured correctly.

HCareLou avatar Mar 18 '24 00:03 HCareLou

When model.add_module is called, it traverses the module's sub-items and calls RegisterComponents() on each of them. Sequential, however, does not perform this step. As a temporary workaround, call RegisterComponents() explicitly in each custom model once its fields have been initialized.

HCareLou avatar Mar 18 '24 07:03 HCareLou

Comparing Sequential and ModuleList, it can be observed that ModuleList overrides RegisterComponents, whereas Sequential does not. Because ModuleList overrides RegisterComponents, it gains the ability to automatically invoke the RegisterComponents of its child items.

HCareLou avatar Mar 18 '24 08:03 HCareLou

Is there a smallish repro case that I can debug?

NiklasGustafsson avatar Mar 18 '24 15:03 NiklasGustafsson

public class AModel : TorchSharp.torch.nn.Module
{
    public TorchSharp.Modules.Sequential cv1;
    public TorchSharp.Modules.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>> cv2;
    public AModel():base(nameof(AModel))
    {
        cv1 = nn.Sequential(Enumerable.Range(0, 3).Select(x => new BModel()));
        cv2 = nn.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>>(Enumerable.Range(0, 3).Select(x => new BModel()).ToArray());
        RegisterComponents();
    }
}

public class BModel : torch.nn.Module<torch.Tensor,torch.Tensor>
{
    public Tensor stride;
    public Tensor stride2;


    public BModel() : base(nameof(BModel))
    {
        stride = torch.ones(100, 100);
        stride2 = torch.ones(100, 100);
    }

    public override torch.Tensor forward(torch.Tensor input)
    {
        throw new NotImplementedException();
    }
}
var a = new AModel();
var aStateDict = a.state_dict();

In this case, the state_dict of cv1 cannot be obtained.

public class AModel : TorchSharp.torch.nn.Module
{
    public TorchSharp.Modules.Sequential cv1;
    public TorchSharp.Modules.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>> cv2;
    public AModel() : base(nameof(AModel))
    {
        cv1 = nn.Sequential(Enumerable.Range(0, 3).Select(x => new BModel()));
        cv2 = nn.ModuleList<torch.nn.Module<torch.Tensor, torch.Tensor>>(Enumerable.Range(0, 3).Select(x => new BModel()).ToArray());
        RegisterComponents();
    }
}

public class BModel : torch.nn.Module<torch.Tensor, torch.Tensor>
{
    public Tensor stride;
    public Tensor stride2;


    public BModel() : base(nameof(BModel))
    {
        stride = torch.ones(100, 100);
        stride2 = torch.ones(100, 100);
        RegisterComponents();
    }

    public override torch.Tensor forward(torch.Tensor input)
    {
        throw new NotImplementedException();
    }
}
var a = new AModel();
var aStateDict = a.state_dict();

In this case, the state_dict of cv1 can be obtained.

HCareLou avatar Mar 19 '24 00:03 HCareLou

So I guess the problem is that Sequential.RegisterComponents won't call RegisterComponents on its submodules.

yueyinqiu avatar Mar 19 '24 01:03 yueyinqiu

Sequential and ModuleList have different implementation methods for RegisterComponents.

HCareLou avatar Mar 19 '24 01:03 HCareLou

I suppose the issue could be simply solved by adding a call to the submodule's RegisterComponents in Sequential.Add.

However actually in my opinion, all the modules should always call RegisterComponents themselves (or register the modules, parameters, buffers in other ways), so there is no need to deal with the submodules because they will do that on their own.

But it seems not... Even ModuleList and Sequential are not doing that... I'm a bit confused now... This makes it impossible to use:

using static TorchSharp.torch.nn;

var l = ModuleList(Linear(1, 1));
Console.WriteLine(l.state_dict().Count); // 0

And RegisterComponents is protected, so I have to create a wrapping module around it? I believe something has been ill-designed here...

yueyinqiu avatar Mar 19 '24 02:03 yueyinqiu

 protected override void RegisterComponents()
 {
     if (_registered) return;

     for (int i = 0; i < _list.Count; i++) {
         register_module($"{i}", _list[i]);
     }
     _registered = true;
 }

This is the implementation of ModuleList, and I think adding similar code in Sequential should be able to fix it.

HCareLou avatar Mar 19 '24 05:03 HCareLou

However, you will still find that Sequential(BModel()).state_dict() does not work, because Sequential's own RegisterComponents is never called, so your models' RegisterComponents is not invoked either. That's probably because Sequential allows modules to be appended dynamically, so they have to be registered dynamically rather than through a single call to RegisterComponents.

So one solution might be to call the submodule's RegisterComponents in Sequential.Add. However it might make RegisterComponents be called too early, especially when the submodules are also mutable. I'm not sure what the expected behavior should be.

And let me repeat my suggestion. Perhaps all the modules should always RegisterComponents themselves, and the parent modules should not care about registering the sub-submodules of its submodules (that is, not to call RegisterComponents on its submodules).

yueyinqiu avatar Mar 19 '24 09:03 yueyinqiu

We can call RegisterComponents once in the top-level model, so that other models will be registered automatically, which is the most convenient.

HCareLou avatar Mar 19 '24 09:03 HCareLou

Ahh... Sequential cannot even access RegisterComponents of its submodules, since it's protected. Now I have no idea how to implement this without breaking other things... (ModuleList uses register_module, but currently Sequential does not; it keeps a List<torch.nn.IModule<Tensor, Tensor>> instead.)

yueyinqiu avatar Mar 19 '24 12:03 yueyinqiu

image I wonder if this could serve as a relatively good solution.

HCareLou avatar Mar 20 '24 00:03 HCareLou

I think this could work without side effects... but... humm... I can't say...

protected override void RegisterComponents()
{
    foreach(var module in this._modules) {
        this.register_module("sub", (nn.Module)module);
        _internal_submodules.Clear();
    }
}

yueyinqiu avatar Mar 20 '24 07:03 yueyinqiu

To facilitate the use of pre-trained weights in TorchSharp, it is advisable to maintain consistency with PyTorch as much as possible.

HCareLou avatar Mar 20 '24 07:03 HCareLou

To facilitate the use of pre-trained weights in PyTorch, it is advisable to maintain consistency with PyTorch as much as possible.

That is what I mean. A module should register its parameters by itself. In PyTorch this is done by __setattr__ and __getattr__ of the module. That is not possible in C#, which is why RegisterComponents exists. If you want a module to behave like PyTorch, it should always call RegisterComponents in its own constructor, rather than rely on others to call it.

In other words, every module should be usable on its own, instead of being required to be part of another module. In PyTorch, __setattr__ and __getattr__ handle that automatically. In C#, a module that does not call RegisterComponents cannot work correctly.

Umm... Perhaps the best solution would be a source generator?

yueyinqiu avatar Mar 20 '24 07:03 yueyinqiu

I'm not sure if this is feasible or not, but the idea is to call RegisterComponents within the parameterless constructor of nn.Module. This way, when you create a custom model that inherits from nn.Module, it will automatically register itself.

HCareLou avatar Mar 20 '24 08:03 HCareLou

I'm not sure if this is feasible or not, but the idea is to call RegisterComponents within the parameterless constructor of nn.Module. This way, when you create a custom model that inherits from nn.Module, it will automatically register itself.

Unfortunately, that's impossible. The constructor of nn.Module runs before the constructor of your model. Thus no values have been assigned to the properties and fields yet, so we can't register them...
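The ordering problem described above can be reproduced without TorchSharp. The following is a minimal sketch in plain C#; ToyBase, ToyModel, and this reflection-based RegisterComponents are hypothetical stand-ins written for illustration, not the TorchSharp implementation. It only shows that a field assigned in the derived constructor body is still null while the base constructor runs.

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

// Hypothetical stand-in for nn.Module; NOT the TorchSharp implementation.
public abstract class ToyBase
{
    // Names of the fields that were non-null at registration time.
    public List<string> Registered { get; } = new List<string>();

    protected ToyBase(bool registerInBaseConstructor)
    {
        // This body runs BEFORE the derived constructor body assigns its fields.
        if (registerInBaseConstructor) RegisterComponents();
    }

    // Records the name of every non-null public instance field.
    protected void RegisterComponents()
    {
        Registered.Clear();
        foreach (FieldInfo field in GetType().GetFields(BindingFlags.Public | BindingFlags.Instance))
        {
            if (field.GetValue(this) != null) Registered.Add(field.Name);
        }
    }
}

public sealed class ToyModel : ToyBase
{
    // Assigned in the constructor body, like `stride` in the repro above.
    public double[] weight;

    public ToyModel(bool registerInBaseConstructor) : base(registerInBaseConstructor)
    {
        weight = new double[] { 1.0, 2.0 };
        // Registering last, after every field is assigned, sees the field.
        if (!registerInBaseConstructor) RegisterComponents();
    }
}

public static class CtorOrderDemo
{
    public static void Main()
    {
        Console.WriteLine(new ToyModel(true).Registered.Count);   // prints 0
        Console.WriteLine(new ToyModel(false).Registered.Count);  // prints 1
    }
}
```

Note that the sketch assigns the field in the constructor body on purpose: C# field initializers run before the base constructor, so a field declared with an initializer would already be set at that point.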

yueyinqiu avatar Mar 20 '24 08:03 yueyinqiu

Switching gears, we could declare properties and then mark them with a custom attribute that has a “name” which will be used as the registration name. Subsequently, we could employ Fody to inject code into the getter and setter methods to handle the registration process. However, this approach is somewhat cumbersome.

HCareLou avatar Mar 20 '24 08:03 HCareLou

Yes I suppose Fody/SourceGenerator could be a beautiful solution. And we can easily expose properties instead of fields in that way. That's also great.

@NiklasGustafsson Could you please take a look at this?

yueyinqiu avatar Mar 20 '24 08:03 yueyinqiu

I have made a simplified demo here: https://github.com/yueyinqiu/TorchSharp.AutoRegister

PS: It's still impossible to get rid of the traditional constructors (and use a primary constructor instead), because we have to access the generated property. So sad :(

yueyinqiu avatar Mar 20 '24 12:03 yueyinqiu

I'm not sure if this is feasible or not, but the idea is to call RegisterComponents within the parameterless constructor of nn.Module. This way, when you create a custom model that inherits from nn.Module, it will automatically register itself.

Unfortunately, that's impossible. The constructor of nn.Module runs before the constructor of your model. Thus no values have been assigned to the properties and fields yet, so we can't register them...

That is exactly right. That's why RegisterComponents exists and needs to be called last in the (custom) module constructor.

NiklasGustafsson avatar Mar 20 '24 19:03 NiklasGustafsson

Switching gears, we could declare properties and then mark them with a custom attribute that has a “name” which will be used as the registration name. Subsequently, we could employ Fody to inject code into the getter and setter methods to handle the registration process. However, this approach is somewhat cumbersome.

That capability already exists. For example, in the rewrite we're working on for some of the standard modules, which will enable more attributes to be exposed, the parameters of Linear are defined as:

            const string WeightComponentName = nameof(weight);
            const string BiasComponentName = nameof(bias);

            public Parameter? bias {
                get => _bias;
                set {
                    _bias?.Dispose();
                    _bias = value?.DetachFromDisposeScope() as Parameter;
                    ConditionallyRegisterParameter(BiasComponentName, _bias);
                }
            }

            public Parameter weight {
                get => _weight!;
                set {
                    if (value is null) throw new ArgumentNullException(nameof(weight));
                    if (value.Handle != _weight?.Handle) {
                        _weight?.Dispose();
                        _weight = (value.DetachFromDisposeScope() as Parameter)!;
                        ConditionallyRegisterParameter(WeightComponentName, _weight);
                    }
                }
            }

            [ComponentName(Name = BiasComponentName)]
            private Parameter? _bias;
            [ComponentName(Name = WeightComponentName)]
            private Parameter? _weight;

NiklasGustafsson avatar Mar 20 '24 19:03 NiklasGustafsson

Perhaps all the modules should always RegisterComponents themselves, and the parent modules should not care about registering the sub-submodules of its submodules (that is, not to call RegisterComponents on its submodules).

Yes, this is exactly the intended protocol, and the documentation says so: https://github.com/dotnet/TorchSharp/wiki/Creating-Your-Own-TorchSharp-Modules

The exception to this rule will be modules such as the Linear module shown above, where the parameters may be altered and the property setter needs to conditionally register a component, which allows you to assign it 'null' as well as overwrite an already existing parameter.
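To make the protocol concrete, here is a toy model of it in plain C#, independent of TorchSharp. ToyModule, ToyLeaf, ToyContainer, RegisterChild, and StateDict are all hypothetical names invented for this sketch. It assumes a container that records its children by index only, so a leaf contributes entries only if it registered its own fields in its own constructor.

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

// Hypothetical stand-ins that only imitate the protocol discussed in this thread.
public abstract class ToyModule
{
    private readonly Dictionary<string, double[]> _parameters = new Dictionary<string, double[]>();
    private readonly Dictionary<string, ToyModule> _children = new Dictionary<string, ToyModule>();

    // Registers this module's OWN public fields; it never touches grandchildren.
    protected void RegisterComponents()
    {
        foreach (FieldInfo field in GetType().GetFields(BindingFlags.Public | BindingFlags.Instance))
        {
            object value = field.GetValue(this);
            if (value is double[] parameter) _parameters[field.Name] = parameter;
            else if (value is ToyModule child) _children[field.Name] = child;
        }
    }

    // Records a child under a name without asking the child to register itself.
    protected void RegisterChild(string name, ToyModule child) => _children[name] = child;

    // Collects parameters recursively, prefixing child entries with "<name>.".
    public Dictionary<string, double[]> StateDict(string prefix = "")
    {
        var result = new Dictionary<string, double[]>();
        foreach (var entry in _parameters) result[prefix + entry.Key] = entry.Value;
        foreach (var child in _children)
            foreach (var inner in child.Value.StateDict(prefix + child.Key + "."))
                result[inner.Key] = inner.Value;
        return result;
    }
}

public sealed class ToyLeaf : ToyModule
{
    public double[] stride = { 1.0 };

    public ToyLeaf(bool selfRegister)
    {
        // Following the protocol means calling this last in the constructor.
        if (selfRegister) RegisterComponents();
    }
}

public sealed class ToyContainer : ToyModule
{
    public ToyContainer(params ToyModule[] items)
    {
        for (int i = 0; i < items.Length; i++) RegisterChild(i.ToString(), items[i]);
    }
}

public static class ProtocolDemo
{
    public static void Main()
    {
        Console.WriteLine(new ToyContainer(new ToyLeaf(false)).StateDict().Count); // prints 0
        Console.WriteLine(new ToyContainer(new ToyLeaf(true)).StateDict().Count);  // prints 1
    }
}
```

In this toy the only entry in the second case is "0.stride". Whether a real container should additionally trigger registration on its children is the open design question in this thread; the sketch does not settle it.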

NiklasGustafsson avatar Mar 20 '24 19:03 NiklasGustafsson

I have made a simplified demo here: https://github.com/yueyinqiu/TorchSharp.AutoRegister

PS: It's still impossible to get rid of the traditional constructors (and use a primary constructor instead), because we have to access the generated property. So sad :(

As much as I dislike relying on reflection, which the current scheme does (I dislike it because it prevents AOT), having to use source code generation adds complexity and something that has to be automated. That would be a last resort, I think.

The current scheme works fairly well as long as you follow the instructions very closely and don't do advanced stuff like the Linear module above. I don't know why you would allow setting the parameters after the module has been constructed, but PyTorch does, so TorchSharp should, too.

NiklasGustafsson avatar Mar 20 '24 19:03 NiklasGustafsson

So, to address the issue of Sequential not registering its sub-modules, for the time being, should we also rewrite Sequential’s RegisterComponents, just like we did with ModuleList?

HCareLou avatar Mar 21 '24 00:03 HCareLou

Perhaps all the modules should always RegisterComponents themselves, and the parent modules should not care about registering the sub-submodules of its submodules (that is, not to call RegisterComponents on its submodules).

Yes, this is exactly the intended protocol, and the documentation says so: https://github.com/dotnet/TorchSharp/wiki/Creating-Your-Own-TorchSharp-Modules

The exception to this rule will be modules such as the Linear module shown above, where the parameters may be altered and the property setter needs to conditionally register a component, which allows you to assign it 'null' as well as overwrite an already existing parameter.

So my understanding is that custom modules should not rely on others calling their RegisterComponents. But then why are we doing that in register_module? I think this may be misleading.

yueyinqiu avatar Mar 21 '24 01:03 yueyinqiu

So, to address the issue of Sequential not registering its sub-modules, for the time being, should we also rewrite Sequential’s RegisterComponents, just like we did with ModuleList?

We can certainly reconsider the protocol for module registration for the future. However, if the guidelines for custom modules described in the Wiki article are followed, the current protocol works.

NiklasGustafsson avatar Mar 21 '24 16:03 NiklasGustafsson

[...] why we are doing that in register_module?

As unsatisfying as this answer is -- I don't recall.

NiklasGustafsson avatar Mar 21 '24 16:03 NiklasGustafsson