
Pretrained model: Non-trainable weights seem to be updated during FL training

Open mirkohin1991 opened this issue 4 years ago • 2 comments

Describe the bug
I am currently experimenting with TFF in combination with my own multi-class dataset. I am using the high-level API, following the classification tutorial.

My central Keras model (ResNet-50, pretrained on ImageNet, only the top layers are trainable) performs quite well: after 3 epochs, train and test accuracy are above 95%. Using the same model setup with TFF and very good starting conditions (2 clients, data split equally, stratified classes) leads to very poor results (validation accuracy "oscillates" between 30% and 50% even after many epochs). In the same federated setup, a very simple NN with only one dense layer immediately reaches a validation accuracy > 80%.
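Roughly, the setup looks like the following sketch (the class count, input shape, dense head and input spec here are placeholders rather than my exact code):

```python
import tensorflow as tf
import tensorflow_federated as tff

NUM_CLASSES = 3            # placeholder
INPUT_SHAPE = (224, 224, 3)  # placeholder

def create_keras_model():
  # Pretrained ResNet-50 backbone, frozen so that only the new top layer trains.
  base = tf.keras.applications.ResNet50(
      include_top=False, weights='imagenet',
      input_shape=INPUT_SHAPE, pooling='avg')
  base.trainable = False
  return tf.keras.Sequential([
      base,
      tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
  ])

def model_fn():
  # input_spec must match the client datasets; this spec is a placeholder.
  input_spec = (
      tf.TensorSpec(shape=[None, *INPUT_SHAPE], dtype=tf.float32),
      tf.TensorSpec(shape=[None], dtype=tf.int64),
  )
  return tff.learning.from_keras_model(
      create_keras_model(),
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
```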

So it seems that there is an issue with the pretrained ResNet-50 model. My feeling is that the "non-trainable" parameters get updated, although this shouldn't be the case. Comparing the evaluation model of the latest FL run with a model that contains the initial weights shows a difference (see the attached weights screenshot).
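The comparison was along these lines (a sketch; `eval_model` holds the weights pulled out of the latest FL state, `initial_model` is a freshly built copy with the pretrained weights):

```python
import numpy as np

def compare_non_trainable(initial_model, eval_model):
  # Report every non-trainable variable whose value changed during FL training.
  for v0, v1 in zip(initial_model.non_trainable_weights,
                    eval_model.non_trainable_weights):
    diff = np.max(np.abs(v0.numpy() - v1.numpy()))
    if diff > 0:
      print(f'{v1.name}: max abs difference {diff:.6f}')
```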

My model architecture looks like this: (model summary screenshot attached)

What is also suspicious: although there are only about 6,000 trainable parameters, the federated training takes a lot longer than with the simple model, which trains about 400,000 parameters.
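(The counts come from model.summary(); they can also be computed directly on the Keras model, e.g. with `model` being the model from the sketch above:)

```python
trainable = sum(int(tf.keras.backend.count_params(w))
                for w in model.trainable_weights)
non_trainable = sum(int(tf.keras.backend.count_params(w))
                    for w in model.non_trainable_weights)
print(f'trainable: {trainable:,}, non-trainable: {non_trainable:,}')
```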

I also get the following warnings when calling the "build_federated_averaging_process" method (screenshot attached). I don't get those warnings when compiling/training the model in a central setting.

Environment (please complete the following information):

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Amazon Linux AMI
  • Python package versions (e.g., TensorFlow Federated, TensorFlow):
    • tensorflow: 2.2.0
    • tensorflow-federated: 0.16.1
  • Python version: 3.6.10
  • Bazel version (if building from source): N/A
  • CUDA/cuDNN version: N/A
  • What TensorFlow Federated execution stack are you using?

Expected behavior
Only the trainable parameters should be adjusted during federated learning. Ultimately, I would expect my pretrained model to perform a lot better.


mirkohin1991 avatar Sep 15 '20 15:09 mirkohin1991

Hi @mirkohin1991. Can you give more details on the ResNet-50 layers? In particular, I suspect that BatchNorm may be causing problems. It's worth noting that BatchNorm exhibits very different behavior in federated settings than in centralized settings (see https://arxiv.org/abs/1910.00189 for more discussion of this).
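To check whether the BatchNorm statistics are what is changing, something like this sketch would list the BatchNorm layers and their non-trainable variables (moving mean/variance); the recursion assumes the ResNet backbone is nested inside your top-level Keras model:

```python
import tensorflow as tf

def list_batchnorm_layers(keras_model):
  # Recurse into nested models (e.g. the ResNet-50 backbone) and print every
  # BatchNormalization layer together with its moving statistics.
  for layer in keras_model.layers:
    if isinstance(layer, tf.keras.Model):
      list_batchnorm_layers(layer)
    elif isinstance(layer, tf.keras.layers.BatchNormalization):
      print(layer.name, [v.name for v in layer.non_trainable_weights])
```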

zcharles8 avatar Sep 22 '20 18:09 zcharles8

Hi @zcharles8 . I figured out that I had to explicitly set the weights again after initializing the server state:

```python
server_state = tff.learning.state_with_new_model_weights(
    server_state,
    trainable_weights=[v.numpy() for v in eval_model.trainable_weights],
    non_trainable_weights=[v.numpy() for v in eval_model.non_trainable_weights])
```
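For context, this is roughly where the call sits in my training flow (a sketch; the optimizer, number of rounds and the federated data pipeline are placeholders, and model_fn / create_keras_model are as in the earlier sketch):

```python
import tensorflow as tf
import tensorflow_federated as tff

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))

server_state = iterative_process.initialize()

# Push the pretrained weights (trainable AND non-trainable) into the server
# state before the first round; otherwise training starts from the freshly
# initialized weights produced by initialize().
eval_model = create_keras_model()
server_state = tff.learning.state_with_new_model_weights(
    server_state,
    trainable_weights=[v.numpy() for v in eval_model.trainable_weights],
    non_trainable_weights=[v.numpy() for v in eval_model.non_trainable_weights])

for round_num in range(NUM_ROUNDS):  # NUM_ROUNDS and federated_train_data are placeholders
  server_state, metrics = iterative_process.next(server_state, federated_train_data)
```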

My intuition was that, when I build the federated averaging process from a model that contains non-trainable parameters, those weights would be taken into account automatically. But that was not the case.

Now it seems to work as expected. Maybe you can add that somewhere in the documentation?

Besides that, I am still getting the following warnings (screenshot attached). How can I solve that?

mirkohin1991 avatar Sep 25 '20 06:09 mirkohin1991