Pass config to constructor when reviving custom functional model
Currently, when a model instantiated from a custom Model subclass is loaded, its config is not passed to its constructor. As a result, custom constructor parameters are not restored and fall back to their default values.
Here is a snippet showing the behavior mentioned:
from keras import Model, layers


class CustomModel(Model):
    def __init__(self, *args, param=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.param = param

    def get_config(self):
        base_config = super().get_config()
        config = {"param": self.param}
        # Dict union with `|` requires Python 3.9+
        return base_config | config


inputs = layers.Input((3,))
outputs = layers.Dense(5)(inputs)
model = CustomModel(inputs=inputs, outputs=outputs, param=3)
new_model = CustomModel.from_config(model.get_config())
print(new_model.param)  # currently prints 1, i.e. the default value of param
This PR proposes to fix this issue by passing config in functional_from_config to the model constructor.
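To illustrate the idea without depending on Keras internals, here is a minimal pure-Python sketch of the pattern. `FunctionalBase`, `revive_functional`, and the `forward_extra` flag are hypothetical names made up for this example; they are not the Keras API and do not reproduce the actual `functional_from_config` implementation. The sketch only shows why forwarding the remaining config entries to the constructor restores `param`.

```python
class FunctionalBase:
    """Toy stand-in for a functional model base class (not Keras itself)."""

    def __init__(self, inputs, outputs, name=None):
        self.inputs = inputs
        self.outputs = outputs
        self.name = name

    def get_config(self):
        return {"inputs": self.inputs, "outputs": self.outputs, "name": self.name}

    @classmethod
    def from_config(cls, config):
        return revive_functional(cls, config)


def revive_functional(cls, config, forward_extra=True):
    """Hypothetical revival helper.

    With forward_extra=False it mimics the behavior described above: only the
    graph-related keys reach the constructor. With forward_extra=True the
    remaining config entries are passed to the constructor as well.
    """
    base_keys = {"inputs", "outputs", "name"}
    base_kwargs = {k: config[k] for k in base_keys if k in config}
    extra = {k: v for k, v in config.items() if k not in base_keys}
    if forward_extra:
        return cls(**base_kwargs, **extra)
    return cls(**base_kwargs)


class CustomModel(FunctionalBase):
    def __init__(self, *args, param=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.param = param

    def get_config(self):
        return {**super().get_config(), "param": self.param}


model = CustomModel(inputs="in", outputs="out", param=3)

# Config dropped on revival: param falls back to its default.
old = revive_functional(CustomModel, model.get_config(), forward_extra=False)
print(old.param)  # 1

# Config forwarded to the constructor: param is restored.
new = CustomModel.from_config(model.get_config())
print(new.param)  # 3
```

In the real library the graph reconstruction is considerably more involved; the only point here is that the subclass constructor needs to receive the non-graph entries of the config.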
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 78.87%. Comparing base (ca88613) to head (0be55ca). Report is 6 commits behind head on master.
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20321      +/-   ##
==========================================
+ Coverage   78.81%   78.87%   +0.06%
==========================================
  Files         512      513       +1
  Lines       49056    49236     +180
  Branches     9033     9076      +43
==========================================
+ Hits        38664    38837     +173
- Misses       8528     8532       +4
- Partials     1864     1867       +3
| Flag | Coverage Δ | |
|---|---|---|
| keras | 78.73% <100.00%> (+0.06%) | :arrow_up: |
| keras-jax | 62.38% <100.00%> (+0.11%) | :arrow_up: |
| keras-numpy | 57.40% <100.00%> (-0.01%) | :arrow_down: |
| keras-tensorflow | 63.64% <100.00%> (+0.08%) | :arrow_up: |
| keras-torch | 62.37% <100.00%> (+0.11%) | :arrow_up: |