keras
PR for solving issue #16797
A minimal change that solves the issue is modifying the method keras.layers.Layer.trainable_weights(), in base_layer.py, by replacing it with the analogous method of keras.Model.
But I still have some questions: does it make sense to have two different methods for gathering weights, one for Models and another for Layers? Shouldn't they be the same?
I did not update the non_trainable_weights method, but the same change should probably be applied there too.
I don't feel very confident about this, as I don't fully understand the internal workings of keras. I just checked where the discrepancy was and tried the obvious thing: copying from where it works (Model) to where it doesn't (Layer).
Thanks for the PR!
I suspect that the reason we have different logic in Layer and Model is to preserve weight ordering backwards compatibility.
How does this change impact weight ordering, e.g. for a layer that includes a sublayer that includes BatchNorm layers?
If weight ordering is not generally impacted, then I would suggest moving the logic for weights/trainable_weights/non_trainable_weights, as well as the deduplication logic, to the Layer class. Model can just inherit from it.
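
For illustration only, here is a minimal pure-Python sketch of order-preserving deduplication by object identity, which is the kind of logic the suggestion above would move into a shared base class. It does not reproduce the actual Keras implementation; `ToyVariable` and `dedup_by_identity` are hypothetical names.

```python
def dedup_by_identity(weights):
    """Return weights with repeats removed, keeping first-seen order."""
    seen_ids = set()
    unique = []
    for w in weights:
        if id(w) not in seen_ids:
            seen_ids.add(id(w))
            unique.append(w)
    return unique


class ToyVariable:
    """Hypothetical stand-in for a weight variable."""

    def __init__(self, name):
        self.name = name


kernel = ToyVariable("kernel")
bias = ToyVariable("bias")
gamma = ToyVariable("gamma")

# "kernel" is reachable twice, e.g. through a shared sublayer.
collected = [kernel, bias, kernel, gamma]
names = [w.name for w in dedup_by_identity(collected)]
print(names)  # ['kernel', 'bias', 'gamma']
```

Because the first occurrence wins, the resulting order depends entirely on the order in which sublayers are traversed, which is why changing the gathering logic can change weight ordering.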
Thanks a lot, @fchollet. I've tried your suggestion, you can take a look at this gist if you want:
https://gist.github.com/JaimeArboleda/7d446a9ba01a9929d603f6a084014741
You are right that the order of weights is affected, so if this has bad consequences (please confirm whether it does), I'll try to find another solution, maybe by modifying the existing code in a less drastic way...
PS: I know this is not the most appropriate place to say it, but if you don't mind, I'd love to thank you for (among other things) your wonderful book on Deep Learning and your terrific tutorials publicly available in the keras docs. I learned a lot from them!
If you take a look at my last commit (if this PR is getting too messy, I can delete it and create a new one), I think I have a better solution that does not affect order of weights.
Basically, I created a method that is called in build() and that updates variables contained in wrapped objects (that might have changed after the creation of the container, a dict or list):
    def _update_trackables(self):
        """This method loops over tracked objects to update tracked variables
        for wrapped lists/dicts that may have been updated after initialization
        (Issue #16797)
        """
        for trackable_obj in self._self_tracked_trackables:
            if isinstance(
                trackable_obj,
                tf.__internal__.tracking.TrackableDataStructure
            ):
                self._track_variables(trackable_obj)
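
To make the idea concrete without depending on Keras internals, here is a toy pure-Python sketch of the same pattern: variables are snapshotted when a container is attached, and a re-scan in build() picks up anything added afterwards. `ToyLayer`, `track` and `update_trackables` are hypothetical stand-ins, and a plain `list` plays the role of the wrapped data structure.

```python
class ToyLayer:
    """Hypothetical stand-in for a Keras layer; not the real API."""

    def __init__(self):
        self.tracked_containers = []  # plays the role of _self_tracked_trackables
        self.tracked_variables = []   # flat list reported as the layer's weights

    def track(self, container):
        # Variables are snapshotted when the container is attached,
        # so anything appended later is missed until a re-scan.
        self.tracked_containers.append(container)
        self._track_variables(container)

    def _track_variables(self, container):
        for var in container:
            if not any(var is seen for seen in self.tracked_variables):
                self.tracked_variables.append(var)

    def update_trackables(self):
        # Re-scan only the container-like trackables.
        for obj in self.tracked_containers:
            if isinstance(obj, list):
                self._track_variables(obj)

    def build(self):
        self.update_trackables()


layer = ToyLayer()
weights = []
layer.track(weights)
weights.append(object())  # added after the container was attached
count_before_build = len(layer.tracked_variables)
layer.build()
count_after_build = len(layer.tracked_variables)
print(count_before_build, count_after_build)  # 0 1
```

The re-scan is idempotent in this sketch: calling build() again does not register the same variable twice.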
> PS: I know this is not the most appropriate place to say it, but if you don't mind, I'd love to thank you for (among other things) your wonderful book on Deep Learning and your terrific tutorials publicly available in the keras docs. I learned a lot from them!
Glad the book was useful! Thanks for the kind words.
I did not know that we could not rely on every layer being built, or even that in the __call__ new weights could be added to a wrapped list/dict.
With respect to the overhead, it will depend on:
- How many weights the model contains, as we would be looping over self._self_tracked_trackables of each layer.
- How many of those tracked_trackables are ListWrapper or DictWrapper, as for them we would be performing an extra inner loop to update the variables inside the wrappers.
So if we had a model with many layers, each layer containing lots of trackables and/or wrappers, the overhead could be noticeable. I think this is uncommon, even in large models (those with billions of trainable parameters), because large models tend to keep their parameters in a few big variables, so the ratio of trackables to parameters is very small.
I can provide two examples run in a Google Colab instance:
- Example 1: GPT-2 from Huggingface, with 124 M parameters. It takes around 0.3 milliseconds to perform an update in all of its components (layers, trackables, etc.) using the proposed code. It should be noted that this model has 148 weights in total.
- Example 2: a model containing 1000 layers, each of them containing 10 inner layers, and each inner layer containing two weights. In this case, to make things "worse", everything (layers, inner layers and weights) is contained in ListWrappers. It takes 300 milliseconds to perform an update in all of its components. It should be noted that this model contains 60 K parameters and 20 K weights.
Given those numbers, do you think we need caching solutions?
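
For anyone who wants a feel for the traversal cost alone, the shape of Example 2 can be mimicked with plain nested lists. This is only a rough sketch: it measures pure-Python iteration, not the Keras tracking code, so the timing it prints is not comparable to the numbers above. `count_leaves` is a hypothetical helper.

```python
import time


def count_leaves(node):
    """Walk nested lists and count the non-list leaves (toy 'weights')."""
    if isinstance(node, list):
        return sum(count_leaves(child) for child in node)
    return 1


# 1000 layers x 10 inner layers x 2 weights, everything held in plain lists.
model = [[[object(), object()] for _ in range(10)] for _ in range(1000)]

start = time.perf_counter()
n_weights = count_leaves(model)
elapsed_ms = (time.perf_counter() - start) * 1000.0
print(f"visited {n_weights} weights in {elapsed_ms:.2f} ms")
```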
On the other hand, I thought about another option, but it would require more changes to the codebase:
- When we create a ListWrapper or a DictWrapper, we add a reference to the parent container of this object (if my understanding is correct, it could be a Layer or another Wrapper).
- When the Wrapper is changed (for example, when appending a value to a ListWrapper), we notify the parent container by calling a specific method that will update its weights and propagate the message upwards until we reach a Layer.
This way, we'll ensure that every change is immediately taken into account, and we won't need to run all those loops just in case.
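
A minimal pure-Python sketch of that notification idea, assuming a wrapper that keeps a reference to its parent and forwards change events upwards until they reach a layer. `NotifyingList`, `ToyLayer` and `on_child_changed` are hypothetical names, not part of Keras or TensorFlow, and only `append` is intercepted here.

```python
class ToyLayer:
    """Hypothetical end of the notification chain."""

    def __init__(self):
        self.weights = []

    def on_child_changed(self, item):
        # Register new leaf items once; nested wrappers are not weights.
        if isinstance(item, NotifyingList):
            return
        if not any(item is w for w in self.weights):
            self.weights.append(item)


class NotifyingList(list):
    """Hypothetical list wrapper that reports appends to its parent."""

    def __init__(self, parent):
        super().__init__()
        self.parent = parent

    def append(self, item):
        super().append(item)
        self.on_child_changed(item)

    def on_child_changed(self, item):
        # Propagate upwards; the parent is a wrapper or a layer.
        self.parent.on_child_changed(item)


layer = ToyLayer()
outer = NotifyingList(parent=layer)
inner = NotifyingList(parent=outer)
outer.append(inner)

w = object()
inner.append(w)  # travels inner -> outer -> layer
print(len(layer.weights))  # 1
```

A real implementation would also need to cover the other mutating operations (extend, insert, item assignment, dict updates), which is part of why this option touches more of the codebase.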
What do you think?
Sorry for the delay in reviewing this.
> On the other hand, I thought about another option

I think simply calling the body of the current _update_trackables(self) from trainable_weights / non_trainable_weights should be sufficient. The overhead does not seem to be significant.
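
As a rough sketch of what refreshing on access looks like (pure Python, not the actual Keras code), the properties below re-scan the tracked containers every time they are read, so weights added at any point, even after build or during a call, are picked up. `ToyLayer`, `ToyVariable`, `track` and `_refresh` are hypothetical names.

```python
class ToyVariable:
    """Hypothetical stand-in for a weight variable."""

    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = trainable


class ToyLayer:
    """Hypothetical layer that refreshes tracked variables on access."""

    def __init__(self):
        self._containers = []
        self._variables = []

    def track(self, container):
        self._containers.append(container)

    def _refresh(self):
        # Re-scan every tracked list/dict for variables not yet registered.
        for container in self._containers:
            values = container.values() if isinstance(container, dict) else container
            for var in values:
                if not any(var is seen for seen in self._variables):
                    self._variables.append(var)

    @property
    def trainable_weights(self):
        self._refresh()
        return [v for v in self._variables if v.trainable]

    @property
    def non_trainable_weights(self):
        self._refresh()
        return [v for v in self._variables if not v.trainable]


layer = ToyLayer()
params = {}
layer.track(params)

# Both entries are added after the dict was attached to the layer.
params["kernel"] = ToyVariable("kernel")
params["moving_mean"] = ToyVariable("moving_mean", trainable=False)

trainable_names = [v.name for v in layer.trainable_weights]
frozen_names = [v.name for v in layer.non_trainable_weights]
print(trainable_names, frozen_names)  # ['kernel'] ['moving_mean']
```

The trade-off is that every property read pays for a full re-scan, which is the overhead discussed earlier in the thread.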
Thank you!
Ok, thanks a lot for reviewing it. I agree with you. I've made a last commit doing exactly what you say.