pytorch-lightning

Make `save_hyperparameters` not save another class's args when `LightningModule` is instantiated within another class's constructor

Open tanmoyio opened this issue 3 years ago • 10 comments

Duplicate args conflict with .save_hyperparameters()

Simple hack to fix the issue https://github.com/PyTorchLightning/pytorch-lightning/issues/13181 by @serena-ruan

Quick test:

class Mod1(LightningModule):
    def __init__(self, same_arg):
        super().__init__()
        self.same_arg = same_arg
        self.save_hyperparameters()

class Mod2(LightningModule):
    def __init__(self, same_arg):
        super().__init__()
        self.same_arg = same_arg
        self.save_hyperparameters()

class Parent():
    def __init__(self, same_arg="parent", diff_arg="test"):
        super().__init__()
        self.m1 = Mod1(same_arg='child1')
        self.m2 = Mod2(same_arg='child2')

parent = Parent()
print(parent.m1.hparams, parent.m2.hparams)
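The cross-contamination comes from how `save_hyperparameters()` collects arguments: it inspects the call stack for `__init__` frames. Below is a minimal pure-Python sketch of that mechanism — not Lightning's actual implementation, and all class/method names are made up for illustration — contrasting a naive frame walk (which absorbs the enclosing constructor's args) with one that stops at the frame constructing this very object:

```python
import inspect

class Recorder:
    def save_hyperparameters_naive(self):
        # Naive: merge args from every __init__ frame on the stack,
        # so an enclosing class's constructor args leak in.
        args = {}
        frame = inspect.currentframe().f_back
        while frame is not None:
            if frame.f_code.co_name == "__init__":
                args.update({k: v for k, v in frame.f_locals.items() if k != "self"})
            frame = frame.f_back
        self.hparams = args

    def save_hyperparameters_fixed(self):
        # Fixed: only use the __init__ frame whose `self` is this very
        # object, ignoring any outer constructors on the stack.
        frame = inspect.currentframe().f_back
        while frame is not None:
            if frame.f_code.co_name == "__init__" and frame.f_locals.get("self") is self:
                self.hparams = {k: v for k, v in frame.f_locals.items() if k != "self"}
                return
            frame = frame.f_back
        self.hparams = {}

class LeakyChild(Recorder):
    def __init__(self, same_arg):
        self.save_hyperparameters_naive()

class CleanChild(Recorder):
    def __init__(self, same_arg):
        self.save_hyperparameters_fixed()

class Outer:
    def __init__(self, same_arg="parent", diff_arg="test"):
        self.leaky = LeakyChild(same_arg="child")
        self.clean = CleanChild(same_arg="child")

outer = Outer()
print(outer.leaky.hparams)  # {'same_arg': 'parent', 'diff_arg': 'test'} -- leaked
print(outer.clean.hparams)  # {'same_arg': 'child'} -- only the child's own arg
```

The "leaky" walk shows both symptoms from the linked issue: the parent's `same_arg` overrides the child's, and the unrelated `diff_arg` is saved too.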

Before submitting

  • [ ] Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • [ ] Did you read the contributor guideline, Pull Request section?
  • [ ] Did you make sure your PR does only one thing, instead of bundling different changes together?
  • [ ] Did you make sure to update the documentation with your changes? (if necessary)
  • [ ] Did you write any new necessary tests? (not for typos and docs)
  • [ ] Did you verify new and existing tests pass locally with your changes?
  • [ ] Did you list all the breaking changes introduced by this pull request?
  • [ ] Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR. Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

  • [ ] Is this pull request ready for review? (if not, please submit in draft mode)
  • [ ] Check that all items from Before submitting are resolved
  • [ ] Make sure the title is self-explanatory and the description concisely explains the PR
  • [ ] Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

tanmoyio avatar Jun 01 '22 14:06 tanmoyio

@tanmoyio could this introduce any breaking changes? I wasn't sure

awaelchli avatar Jun 02 '22 00:06 awaelchli

@awaelchli It will not. One unit test in test_parsing.py broke, which was expected; I updated it.

tanmoyio avatar Jun 02 '22 14:06 tanmoyio

@tanmoyio Just one more question regarding the issue I raised: this PR gives the child's params priority, but it still saves whatever params exist in the parent. Also, in my issue the two classes are not in an inheritance relationship, so will you solve that problem in a separate PR?

serena-ruan avatar Jun 08 '22 07:06 serena-ruan

@serena-ruan

class Boring_saving_hparams(LightningModule):
    def __init__(self, same_arg):
        super().__init__()
        self.same_arg = same_arg
        self.save_hyperparameters()

class BoringParent():
    def __init__(self, same_arg="parent", diff_arg="test"):
        super().__init__()
        self.model = Boring_saving_hparams(same_arg="child")

in your code, even though BoringParent doesn't inherit from anything, it still instantiates a Boring_saving_hparams object, which calls save_hyperparameters(). So it will keep track of the hparams, and it should.

@awaelchli what do you think ?

tanmoyio avatar Jun 08 '22 10:06 tanmoyio

class Mod1(LightningModule):
    def __init__(self, same_arg):
        super().__init__()
        self.same_arg = same_arg
        self.save_hyperparameters()

class Mod2(LightningModule):
    def __init__(self, same_arg):
        super().__init__()
        self.same_arg = same_arg
        self.save_hyperparameters()

class Parent():
    def __init__(self, same_arg="parent", diff_arg="test"):
        super().__init__()
        self.m1 = Mod1(same_arg='child1')
        self.m2 = Mod2(same_arg='child2')

can you share some cases where this kind of inheritance is used? LightningModule is a system. If your child components just contain some layers, they should be defined as nn.Modules.

rohitgr7 avatar Jun 22 '22 13:06 rohitgr7

@rohitgr7 reference https://github.com/Lightning-AI/lightning/issues/13181

tanmoyio avatar Jun 22 '22 14:06 tanmoyio

can you share some cases where this kind of inheritance is used? LightningModule is a system. If your child components just contain some layers, they should be defined as nn.Modules.

@rohitgr7 I have the same thought and am still curious about the actual use cases (cc @tanmoyio @serena-ruan), but at the same time, I think we can't assume that a lightning module is always instantiated outside any class's __init__.

akihironitta avatar Jul 05 '22 04:07 akihironitta

in your code, even though BoringParent doesn't inherit from anything, it still instantiates a Boring_saving_hparams object, which calls save_hyperparameters(). So it will keep track of the hparams, and it should.

@tanmoyio I don't think it should, since that's completely out of PL's scope. If one wants to log other args of the wrapping class, they should be passed to the lightning module's __init__.

In short, I believe the following assertions need to pass in this PR as @serena-ruan asks in https://github.com/Lightning-AI/lightning/pull/13202#issuecomment-1149552152:

from pytorch_lightning import LightningModule, Trainer

class BoringModel(LightningModule):
    def __init__(self, same_arg):
        super().__init__()
        self.save_hyperparameters()

class SomeClass:
    def __init__(self, same_arg="parent", diff_arg="test"):
        super().__init__()
        self.model = BoringModel(same_arg="child")
        assert self.model.hparams.same_arg == "child", f"'{self.model.hparams.same_arg}' should be 'child'."
        assert not hasattr(self.model.hparams, "diff_arg"), "`diff_arg` shouldn't be saved in `self.model.hparams`."

some_class = SomeClass()

akihironitta avatar Jul 05 '22 04:07 akihironitta

I propose to close this PR. Not calling save_hyperparameters() on the "root" module is a misuse and should be prohibited: if only a nested module records its hparams, we can't reliably instantiate the model using the load_from_checkpoint function.
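To see why, recall that load_from_checkpoint re-instantiates the class from the saved hyperparameters, so they must live on the module being restored. A toy sketch of that round trip — not Lightning's actual code, with made-up names (MiniModule, hidden_dim, lr):

```python
# Toy sketch (not Lightning's real implementation) of why
# load_from_checkpoint needs hparams recorded on the root module.
class MiniModule:
    def __init__(self, hidden_dim, lr):
        self.hidden_dim = hidden_dim
        self.lr = lr
        # emulate save_hyperparameters(): record own constructor args
        self.hparams = {"hidden_dim": hidden_dim, "lr": lr}

def save_checkpoint(module):
    # a real checkpoint also holds weights; the hparams are what make
    # the module re-instantiable
    return {"hyper_parameters": dict(module.hparams)}

def load_from_checkpoint(cls, checkpoint):
    # re-instantiate purely from the saved hparams -- impossible if the
    # root module never called save_hyperparameters()
    return cls(**checkpoint["hyper_parameters"])

model = MiniModule(hidden_dim=32, lr=1e-3)
restored = load_from_checkpoint(MiniModule, save_checkpoint(model))
print(restored.hidden_dim, restored.lr)  # 32 0.001
```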

awaelchli avatar Aug 08 '22 20:08 awaelchli

@awaelchli It's not clear to me whether your argument also applies to the linked issue and whether it should be closed too as non-supported usage. Closing is your call, as you are the most familiar with this topic :)

carmocca avatar Aug 09 '22 15:08 carmocca

@carmocca Actually, I am simply referring to the use case given in the description here. I just realized it is not at all the same as in the linked issue.

The linked issue should be addressed, but it is clearly different. This PR tries to resolve a limitation with inheritance that I argue can't be logically supported. On the other hand, the reported issue is about composition, and I think there is a bug.

I will check if we can update this PR or need to open a new one.

awaelchli avatar Aug 10 '22 10:08 awaelchli