Pretrained model: Non-trainable weights seem to be updated during FL training
Describe the bug
I am currently experimenting with TFF in combination with my own multi-class dataset, using the high-level API and following the classification tutorial.
My central Keras model (ResNet50, pretrained on ImageNet, only the top layers are trainable) performs quite well: after 3 epochs, train and test accuracy are above 95%. Using the same model setup with TFF and very good starting conditions (2 clients, data split equally, stratified classes) leads to very poor results (validation accuracy oscillates between 30% and 50% even after many epochs). In the same federated setup, a very simple NN with only one dense layer immediately reaches a validation accuracy above 80%.
So it seems that there is an issue with the pretrained ResNet50 model. My feeling is that the "non-trainable" parameters get updated, although this shouldn't be the case.
Comparing the evaluation model of the latest FL run with a model that contains the initial weights shows a difference:
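A check along these lines might look roughly as follows (a sketch only; `pretrained_model` and `state` are placeholder names for a freshly built Keras model and the server state after the last round):

```python
import numpy as np

# Compare the non-trainable weights (e.g. BatchNorm moving statistics) of a
# freshly built pretrained model against the server state after FL training.
initial = [v.numpy() for v in pretrained_model.non_trainable_weights]
trained = state.model.non_trainable  # ModelWeights of the final ServerState

for before, after in zip(initial, trained):
    if not np.allclose(before, after):
        print("non-trainable weight changed, max abs diff:",
              np.abs(before - after).max())
```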
My model architecture looks like this:
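Roughly, the described setup corresponds to something like this (the number of classes and the exact top layer are assumptions, not taken from the original post):

```python
import tensorflow as tf

NUM_CLASSES = 3  # assumption; the post only mentions a multi-class dataset

def create_keras_model():
    # ResNet-50 backbone pretrained on ImageNet, frozen so that only the
    # newly added classification head is trainable.
    base = tf.keras.applications.ResNet50(
        include_top=False, weights='imagenet',
        input_shape=(224, 224, 3), pooling='avg')
    base.trainable = False
    return tf.keras.Sequential([
        base,
        tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
    ])
```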
What is also suspicious: although there are only 6,000 trainable parameters, the federated training takes much longer than with the simple model, which trains 400,000 parameters.
I also get the following warnings when calling the `build_federated_averaging_process` method. I don't get those warnings when compiling/training the model in a central setting.
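For context, the process is built roughly like this (the input spec, optimizers, and learning rates below are illustrative assumptions; `example_dataset` stands for one client's `tf.data.Dataset`):

```python
import tensorflow as tf
import tensorflow_federated as tff

def model_fn():
    # TFF requires a fresh, uncompiled Keras model on every call.
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=example_dataset.element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
```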
Environment (please complete the following information):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Amazon Linux AMI
- Python package versions (e.g., TensorFlow Federated, TensorFlow):
- tensorflow: 2.2.0
- tensorflow-federated: 0.16.1
- Python version: 3.6.10
- Bazel version (if building from source): N/A
- CUDA/cuDNN version: N/A
- What TensorFlow Federated execution stack are you using?
Expected behavior
Only the trainable parameters should be adjusted during federated learning. Ultimately, I would expect my pretrained model to perform a lot better.
Hi @mirkohin1991. Can you give more details on the ResNet-50 layers? In particular, I suspect that BatchNorm may be causing problems. It's worth noting that BatchNorm exhibits very different behavior in federated settings than in centralized settings (see https://arxiv.org/abs/1910.00189 for more discussion).
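A quick way to see the BatchNorm state involved, assuming the model is built as sketched above (the moving mean and variance of each BatchNorm layer are non-trainable variables, which is why they show up in the model's non-trainable weights):

```python
import tensorflow as tf

model = create_keras_model()
backbone = model.layers[0]  # the ResNet-50 base

# List the BatchNorm layers in the frozen backbone and a few of the
# non-trainable variables they contribute (moving_mean / moving_variance).
bn_layers = [l for l in backbone.layers
             if isinstance(l, tf.keras.layers.BatchNormalization)]
print(len(bn_layers), "BatchNorm layers in the backbone")

for v in model.non_trainable_variables[:4]:
    print(v.name, v.shape)
```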
Hi @zcharles8 . I figured out that I had to explicitly set the weights again after initializing the server state:
```python
server_state = tff.learning.state_with_new_model_weights(
    server_state,
    trainable_weights=[v.numpy() for v in eval_model.trainable_weights],
    non_trainable_weights=[
        v.numpy() for v in eval_model.non_trainable_weights
    ])
```
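For completeness, a sketch of where that call sits in the overall flow (`federated_train_data` and the number of rounds are placeholders):

```python
server_state = iterative_process.initialize()

# ... overwrite the server weights with the pretrained ones, as shown above ...

for round_num in range(10):
    server_state, metrics = iterative_process.next(
        server_state, federated_train_data)
    print(round_num, metrics)
```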
My intuition was that, when I build the federated averaging process with a model that contains non-trainable parameters, this setting would be taken into account. But that was not the case.
Now it seems to work as expected. Maybe you can add that somewhere in the documentation?
Besides that, I am still getting the following warnings:
How can I solve that?