Pass config to constructor when reviving custom functional model

Open TrAyZeN opened this issue 1 year ago • 2 comments

Currently, when loading a functional model instantiated from a custom Model subclass, its config is not passed to its constructor. As a result, custom constructor parameters are reset to their default values instead of being restored.

Here is a snippet showing the behavior mentioned:

from keras import Model, layers

class CustomModel(Model):
    def __init__(self, *args, param=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.param = param

    def get_config(self):
        # Extend the base config with the custom constructor argument.
        base_config = super().get_config()
        config = {"param": self.param}
        return base_config | config  # dict union requires Python 3.9+

# Build a functional model through the custom subclass.
inputs = layers.Input((3,))
outputs = layers.Dense(5)(inputs)
model = CustomModel(inputs=inputs, outputs=outputs, param=3)

# Round-trip through the config.
new_model = CustomModel.from_config(model.get_config())
print(new_model.param)  # currently prints 1, i.e. the default value of param

This PR proposes to fix the issue by having functional_from_config pass the config to the model constructor.
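To make the difference concrete without depending on Keras internals, here is a minimal pure-Python sketch of the two revival strategies. Everything in it (ToyFunctional, ToyCustom, GRAPH_KEYS, revive_dropping_extras, revive_forwarding_extras) is a hypothetical stand-in invented for illustration; it is not the code of functional_from_config and only mirrors the behavior described above under that assumption.

```python
# Toy stand-ins for illustration only; none of these names are Keras APIs.
class ToyFunctional:
    def __init__(self, inputs, outputs, name="toy"):
        self.inputs = inputs
        self.outputs = outputs
        self.name = name

    def get_config(self):
        return {"inputs": self.inputs, "outputs": self.outputs, "name": self.name}


class ToyCustom(ToyFunctional):
    def __init__(self, *args, param=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.param = param

    def get_config(self):
        # Same pattern as the snippet above: base config plus the custom argument.
        return {**super().get_config(), "param": self.param}


# Keys the toy reviver treats as describing the graph itself (an assumption).
GRAPH_KEYS = ("inputs", "outputs", "name")


def revive_dropping_extras(cls, config):
    # Mirrors the reported behavior: only graph-related keys reach the constructor.
    return cls(**{key: config[key] for key in GRAPH_KEYS})


def revive_forwarding_extras(cls, config):
    # Mirrors the proposed behavior: the remaining config keys are forwarded too.
    graph = {key: config[key] for key in GRAPH_KEYS}
    extras = {key: value for key, value in config.items() if key not in GRAPH_KEYS}
    return cls(**graph, **extras)


model = ToyCustom(inputs="x", outputs="y", param=3)
config = model.get_config()

dropped = revive_dropping_extras(ToyCustom, config)
forwarded = revive_forwarding_extras(ToyCustom, config)
print(dropped.param, forwarded.param)  # prints: 1 3
```

In this sketch the first reviver loses param because the constructor never sees it, while the second restores it by forwarding every config key that is not part of the graph description.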

TrAyZeN, Oct 03 '24 16:10

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

google-cla[bot], Oct 03 '24 16:10

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 78.87%. Comparing base (ca88613) to head (0be55ca). Report is 6 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20321      +/-   ##
==========================================
+ Coverage   78.81%   78.87%   +0.06%     
==========================================
  Files         512      513       +1     
  Lines       49056    49236     +180     
  Branches     9033     9076      +43     
==========================================
+ Hits        38664    38837     +173     
- Misses       8528     8532       +4     
- Partials     1864     1867       +3     
| Flag | Coverage Δ |
| --- | --- |
| keras | 78.73% <100.00%> (+0.06%) :arrow_up: |
| keras-jax | 62.38% <100.00%> (+0.11%) :arrow_up: |
| keras-numpy | 57.40% <100.00%> (-0.01%) :arrow_down: |
| keras-tensorflow | 63.64% <100.00%> (+0.08%) :arrow_up: |
| keras-torch | 62.37% <100.00%> (+0.11%) :arrow_up: |


codecov-commenter, Oct 03 '24 17:10