FedOpt strategies: Central model gives NaN after second aggregation
Describe the bug
When I use any of the FedOpt strategies (FedAdam, FedYogi, FedAdagrad), training is very unstable: the model outputs NaNs after the second or third aggregation.
Steps/Code to Reproduce
Strategy:
strategy = fl.server.strategy.FedYogi(
    fraction_fit=fc / ac,
    fraction_eval=0.2,  # not used - no federated evaluation
    min_fit_clients=fc,
    min_eval_clients=2,  # not used
    min_available_clients=ac,
    eval_fn=get_eval_fn(model),
    on_fit_config_fn=fit_config,
    on_evaluate_config_fn=evaluate_config,
    initial_parameters=fl.common.weights_to_parameters(model_weights),
)

def get_eval_fn(model):
    """Return an evaluation function for server-side evaluation."""
    _, testset, _ = utils.load_isic_by_patient_server()
    testloader = DataLoader(testset, batch_size=16, num_workers=4, worker_init_fn=utils.seed_worker, shuffle=False)

    # The `evaluate` function will be called after every round
    def evaluate(
        weights: fl.common.Weights,
    ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
        # Update model with the latest parameters
        set_parameters(model, weights)
        loss, auc, accuracy, f1 = utils.val(model, testloader, criterion=nn.BCEWithLogitsLoss())
        return float(loss), {"accuracy": float(accuracy), "auc": float(auc)}

    return evaluate

def fit_config(rnd: int):
    """Return training configuration dict for each round.

    Keep batch size fixed at 32, perform two rounds of training with one
    local epoch, increase to two local epochs afterwards.
    """
    config = {
        "batch_size": 32,
        "local_epochs": 1 if rnd < 2 else 2,
    }
    return config

def evaluate_config(rnd: int):
    """Return evaluation configuration dict for each round.

    Perform five local evaluation steps on each client (i.e., use five
    batches) during rounds one to three, then increase to ten local
    evaluation steps.
    """
    val_steps = 5 if rnd < 4 else 10
    return {"val_steps": val_steps}

I have also tried the default eval_fn and the default fit and evaluate configs. The behavior changes, but the model still ends up producing NaNs.
Dataset: ISIC 2020 https://www.kaggle.com/c/siim-isic-melanoma-classification/data
I have tested with fc=ac=3 and fc=ac=2. In the latter case, one client has a training set of ~2k images and the other ~10.5k.
Model: EfficientNetB2.
Code: https://github.com/sandracl72/flower
server_advanced.py (--nowandb)
client_isic.py --partition 0 (--nowandb)
client_isic.py --partition 1 (--nowandb)
Expected Results
The server aggregates the weights of all clients without the model producing NaNs.
Actual Results
INFO flower 2022-02-11 09:28:34,874 | app.py:109 | Flower server running (10 rounds)
SSL is disabled
INFO flower 2022-02-11 09:28:34,875 | server.py:118 | Initializing global parameters
INFO flower 2022-02-11 09:28:34,875 | server.py:301 | Using initial parameters provided by strategy
INFO flower 2022-02-11 09:28:34,875 | server.py:120 | Evaluating initial parameters
INFO flower 2022-02-11 09:28:54,419 | server.py:123 | initial parameters (loss, other metrics): 0.6916685566558676, {'accuracy': 0.5218716861081655, 'auc': 0.48442367381213036}
INFO flower 2022-02-11 09:28:54,419 | server.py:133 | FL starting
DEBUG flower 2022-02-11 09:28:54,419 | server.py:252 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-02-11 09:30:50,949 | server.py:261 | fit_round received 2 results and 0 failures
INFO flower 2022-02-11 09:31:09,522 | server.py:148 | fit progress: (1, 0.32495464879074293, {'accuracy': 0.901643690349947, 'auc': 0.8418079867528361}, 135.10243083909154)
INFO flower 2022-02-11 09:31:09,522 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-02-11 09:31:09,522 | server.py:252 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-02-11 09:33:03,801 | server.py:261 | fit_round received 2 results and 0 failures
Traceback (most recent call last):
File "server_advanced.py", line 130, in
Hi @sandracl72 ,
Thanks for this. I'm guessing you have already tried with FedAvg and it works, right?
Could you run a few tests for me, please?
In your FedYogi, could you please print the values of the following:
- fedavg_weights_aggregate in line 135
- delta_t in lines 136 and 140
- m_t in lines 141 and 147
- v_t in lines 151 and 155
- new_weights in line 160
Basically, what I'd need to see are the values of those variables before and after the transformations. Thanks
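For anyone following along, the quantities listed above are the intermediate values of a Yogi-style server update. Below is a minimal NumPy sketch of one such step. It assumes the update rule usually given for FedYogi in the adaptive federated optimization literature; the function name `yogi_step` and the hyperparameter defaults are illustrative assumptions, not copied from Flower's source.

```python
import numpy as np

def yogi_step(current, averaged, m_t, v_t, eta=1e-2, beta_1=0.9, beta_2=0.99, tau=1e-3):
    """One server-side Yogi-style update over lists of per-layer arrays.

    `yogi_step` and its default hyperparameters are hypothetical and only
    illustrate the roles of delta_t, m_t, v_t and new_weights.
    """
    # Pseudo-gradient: how far the averaged client weights moved from the global ones.
    delta_t = [a - c for a, c in zip(averaged, current)]
    # First moment: exponential moving average of delta_t.
    m_t = [beta_1 * m + (1.0 - beta_1) * d for m, d in zip(m_t, delta_t)]
    # Second moment: Yogi's sign-controlled update of the squared pseudo-gradient.
    v_t = [v - (1.0 - beta_2) * d * d * np.sign(v - d * d) for v, d in zip(v_t, delta_t)]
    # Adaptive step: each entry is scaled by sqrt(v_t) + tau.
    new_weights = [c + eta * m / (np.sqrt(v) + tau) for c, m, v in zip(current, m_t, v_t)]
    return new_weights, m_t, v_t, delta_t

# Toy example with a single two-element "layer".
current = [np.array([1.0, 0.01])]
averaged = [np.array([1.1, 0.011])]
m0 = [np.zeros(2)]
v0 = [np.zeros(2)]

new_weights, m1, v1, delta = yogi_step(current, averaged, m0, v0)
print("delta_t    ", delta[0])
print("m_t        ", m1[0])
print("v_t        ", v1[0])
print("new_weights", new_weights[0])
```

Printing these four lists before and after each transformation, as requested, shows which step (if any) first produces a non-finite or implausible value.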
Thank you for your answer @pedropgusmao . Yes, I've performed several experiments using FedAvg and never had this issue before.
Here you can access the output log: https://drive.google.com/file/d/1TWg7EgCHFbDMID7SaW78B1vNnfemWyjZ/view?usp=sharing
Thanks!
Hi @sandracl72, thanks for this.
From the log file, I don't see a NaN in the prints, only in:
File "/workspace/flower/src/py/flwr/server/strategy/fedavg.py", line 178, in evaluate.
Could you remove the previous prints and instead print parameters right at the beginning of this function, please? https://github.com/adap/flower/blob/5fb4f9c1cd0070495049f45f382407e9d95166cd/src/py/flwr/server/strategy/fedavg.py#L173
Hi @pedropgusmao ,
I printed sum([np.isnan(w).sum() for w in weights]) before the evaluation to check whether the parameters being evaluated contain NaNs, and apparently they don't. This is the output log:
INFO flower 2022-02-16 08:30:10,677 | app.py:109 | Flower server running (10 rounds)
SSL is disabled
INFO flower 2022-02-16 08:30:10,678 | server.py:118 | Initializing global parameters
INFO flower 2022-02-16 08:30:10,678 | server.py:301 | Using initial parameters provided by strategy
INFO flower 2022-02-16 08:30:10,678 | server.py:120 | Evaluating initial parameters
INFO flower 2022-02-16 08:30:10,678 | fedavg.py:175 | Evaluate model parameters using eval fcn
INFO flower 2022-02-16 08:30:10,797 | fedavg.py:182 | Number of weights with NaN value: 0
INFO flower 2022-02-16 08:30:30,064 | server.py:123 | initial parameters (loss, other metrics): 0.6916685566558676, {'accuracy': 0.5218716861081655, 'auc': 0.48442367381213036}
INFO flower 2022-02-16 08:30:30,064 | server.py:133 | FL starting
DEBUG flower 2022-02-16 08:32:52,575 | server.py:252 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-02-16 08:33:44,737 | server.py:261 | fit_round received 2 results and 0 failures
INFO flower 2022-02-16 08:33:45,329 | fedavg.py:175 | Evaluate model parameters using eval fcn
INFO flower 2022-02-16 08:33:45,438 | fedavg.py:182 | Number of weights with NaN value: 0
INFO flower 2022-02-16 08:34:03,231 | server.py:148 | fit progress: (1, 0.3569583234853719, {'accuracy': 0.8939554612937434, 'auc': 0.7939201167759478}, 213.1666322145611)
INFO flower 2022-02-16 08:34:03,232 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-02-16 08:34:03,233 | server.py:252 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-02-16 08:34:52,459 | server.py:261 | fit_round received 2 results and 0 failures
INFO flower 2022-02-16 08:34:53,040 | fedavg.py:175 | Evaluate model parameters using eval fcn
INFO flower 2022-02-16 08:34:53,155 | fedavg.py:182 | Number of weights with NaN value: 0
Traceback (most recent call last):
File "server_advanced.py", line 135, in <module>
fl.server.start_server("0.0.0.0:8080", config={"num_rounds": rounds}, strategy=strategy)
File "/workspace/flower/src/py/flwr/server/app.py", line 111, in start_server
hist = _fl(
File "/workspace/flower/src/py/flwr/server/app.py", line 148, in _fl
hist = server.fit(num_rounds=config["num_rounds"])
File "/workspace/flower/src/py/flwr/server/server.py", line 145, in fit
res_cen = self.strategy.evaluate(parameters=self.parameters)
File "/workspace/flower/src/py/flwr/server/strategy/fedavg.py", line 183, in evaluate
eval_res = self.eval_fn(weights)
File "server_advanced.py", line 51, in evaluate
loss, auc, accuracy, f1 = utils.val(model, testloader, criterion = nn.BCEWithLogitsLoss())
File "/workspace/flower/utils.py", line 468, in val
val_accuracy = accuracy_score(val_gt2, torch.round(pred2))
File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/validation.py", line 72, in inner_f
return f(**kwargs)
File "/opt/conda/lib/python3.8/site-packages/sklearn/metrics/_classification.py", line 187, in accuracy_score
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
File "/opt/conda/lib/python3.8/site-packages/sklearn/metrics/_classification.py", line 83, in _check_targets
type_pred = type_of_target(y_pred)
File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/multiclass.py", line 287, in type_of_target
_assert_all_finite(y)
File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/validation.py", line 96, in _assert_all_finite
raise ValueError(
ValueError: Input contains NaN, infinity or a value too large for dtype('float32').
So it seems the NaNs appear in the forward pass after the second aggregation.
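One caveat about the check used above: np.isnan counts only NaNs, while the sklearn error also covers infinities and values too large for float32. A slightly broader check might look like the sketch below. The helper name `summarize_weights` is hypothetical, and this is only a debugging aid, not part of Flower.

```python
import numpy as np

def summarize_weights(weights):
    """Report NaN/Inf counts and the finite value range for each array.

    `summarize_weights` is a hypothetical helper for debugging only.
    """
    report = []
    for index, w in enumerate(weights):
        w = np.asarray(w, dtype=np.float64)
        finite = np.isfinite(w)
        has_finite = bool(finite.any())
        report.append({
            "index": index,
            "nan": int(np.isnan(w).sum()),
            "inf": int(np.isinf(w).sum()),
            # Largest magnitude and smallest value among the finite entries.
            "max_abs": float(np.abs(w[finite]).max()) if has_finite else float("nan"),
            "min": float(w[finite].min()) if has_finite else float("nan"),
        })
    return report

# Toy arrays standing in for the aggregated weights.
example = [
    np.array([0.5, -2.0]),
    np.array([np.inf, 1.0]),
    np.array([1e30, np.nan]),
]
rep = summarize_weights(example)
for row in rep:
    print(row)
```

Logging the minimum per array can also reveal finite but suspicious values, such as a negative entry in a tensor that is expected to stay non-negative.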
Here you can see the histograms of the weights and gradients of the aggregated model: https://wandb.ai/eyeforai/dai-healthcare/reports/FedOpt-Debug--VmlldzoxNTY4ODQ1?accessToken=4kqdpvadvojfpd8my9iflqsw0zk7d9d0xativ7eu5ad69t17ovi7fm91v2co0oy4
Hi @sandracl72 , I might be wrong, but could it be that evaluation is performing something weird? From the log I see
INFO flower 2022-02-16 08:34:53,155 | fedavg.py:182 | Number of weights with NaN value: 0 right before evaluation actually starts.
Could you also use the same isnan check to print the weights inside https://github.com/adap/flower/blob/d80c8c2738b79badbc820f940efcd7fb4fff9503/src/py/flwr/server/strategy/fedyogi.py#L161, please?
Thanks
@danieljanes this might be related to BatchNorm. Do you remember if we had any example with aggregation and batch norm?
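One way BatchNorm could, in principle, turn finite weights into NaN outputs is a running variance that ends up slightly negative after a server-side update: in evaluation mode the layer divides by sqrt(running_var + eps). The NumPy sketch below only illustrates that mechanism. It is a hypothesis, not a confirmed diagnosis of this issue, and the numbers are made up.

```python
import numpy as np

def batchnorm_inference(x, running_mean, running_var, gamma, beta, eps=1e-5):
    """Eval-mode batch norm: normalize with stored running statistics."""
    return gamma * (x - running_mean) / np.sqrt(running_var + eps) + beta

x = np.array([0.2, -0.4, 1.0])
running_mean = np.zeros(3)
gamma = np.ones(3)
beta = np.zeros(3)

healthy_var = np.array([1.0, 0.5, 0.02])
# Hypothetical: an optimizer step nudged a small variance just below zero.
shifted_var = np.array([1.0, 0.5, -0.01])

# Suppress NumPy's "invalid value" warning for the sqrt of a negative number.
with np.errstate(invalid="ignore"):
    ok = batchnorm_inference(x, running_mean, healthy_var, gamma, beta)
    bad = batchnorm_inference(x, running_mean, shifted_var, gamma, beta)

print("NaN in stored variance:", np.isnan(shifted_var).any())
print("NaN in healthy output: ", np.isnan(ok).any())
print("NaN in shifted output: ", np.isnan(bad).any())
```

If this were the cause, an isnan check on the weights would report zero while the first forward pass would still produce NaNs, which matches the symptom described in this thread.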
@danieljanes the weights aren't NaN before or after the aggregation, but after the forward pass (in the second evaluation round) the results are NaN. We thought it might be related to the BN, and now I have found that if I don't aggregate the "bn" weights, it works fine.
Changes:
def set_parameters(self, parameters: List[np.ndarray]) -> None:
    # Set model parameters from a list of NumPy ndarrays
    keys = [k for k in self.model.state_dict().keys() if 'bn' not in k]
    params_dict = zip(keys, parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    self.model.load_state_dict(state_dict, strict=False)

def get_parameters(self) -> List[np.ndarray]:
    return [val.cpu().numpy() for name, val in self.model.state_dict().items() if 'bn' not in name]
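A note on the change above: it works only if get_parameters and set_parameters apply exactly the same filter in the same order, because zip pairs keys and arrays by position and silently stops at the shorter sequence. The sketch below shows the idea with a plain dict standing in for the model's state dict. The names `is_shared`, `get_shared` and `set_shared` are hypothetical, and the 'bn' substring rule assumes the normalization layers are actually named that way in the model.

```python
from collections import OrderedDict
import numpy as np

def is_shared(name):
    """Hypothetical rule: share every entry whose name does not contain 'bn'."""
    return "bn" not in name

def get_shared(state):
    """Collect the arrays that are sent to the server, in state order."""
    return [value.copy() for name, value in state.items() if is_shared(name)]

def set_shared(state, arrays):
    """Write server arrays back, failing loudly on any mismatch."""
    keys = [name for name in state if is_shared(name)]
    if len(keys) != len(arrays):
        raise ValueError(f"expected {len(keys)} arrays, got {len(arrays)}")
    for key, arr in zip(keys, arrays):
        if state[key].shape != arr.shape:
            raise ValueError(f"shape mismatch for {key}")
        state[key] = arr.copy()

# A plain dict standing in for a model's state dict.
state = OrderedDict([
    ("conv.weight", np.ones((2, 2))),
    ("bn1.running_mean", np.zeros(2)),
    ("bn1.running_var", np.ones(2)),
    ("fc.weight", np.full((1, 2), 3.0)),
])

shared = get_shared(state)
set_shared(state, [w * 2 for w in shared])  # pretend the server doubled them
print({name: value.tolist() for name, value in state.items()})
```

The explicit length and shape checks replace the silent truncation of zip, so a mismatch between the two filters shows up as an error instead of as misassigned weights.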
This is being investigated, but for now, I'd recommend:
- Swapping BatchNorm layers for GroupNorm layers.
- Making sure grpcio==1.43 is in poetry.lock. If not, run poetry add grpcio==1.43. This is related to https://github.com/ray-project/ray/issues/22518
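On the first recommendation: a plausible reason the swap helps is that GroupNorm computes its statistics from each input sample at inference time, so there are no running mean or variance buffers for the server to aggregate, only the learnable scale and shift. In PyTorch the corresponding layer is torch.nn.GroupNorm. The NumPy sketch below just illustrates the computation; the function `group_norm` is a hypothetical stand-in, not the PyTorch implementation.

```python
import numpy as np

def group_norm(x, num_groups, gamma, beta, eps=1e-5):
    """Group normalization for an input of shape (N, C, H, W).

    Statistics come from the input itself, so no running buffers are kept.
    """
    n, c, h, w = x.shape
    if c % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    # Put each group's channels and spatial positions on one axis.
    grouped = x.reshape(n, num_groups, -1)
    mean = grouped.mean(axis=2, keepdims=True)
    var = grouped.var(axis=2, keepdims=True)  # biased variance, always >= 0
    normed = ((grouped - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)
    # Per-channel learnable scale and shift.
    return normed * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 3, 3))
out = group_norm(x, num_groups=2, gamma=np.ones(4), beta=np.zeros(4))

# With unit scale and zero shift, every group has roughly zero mean and unit variance.
print(out.reshape(2, 2, -1).mean(axis=2))
print(out.reshape(2, 2, -1).var(axis=2))
```

Because the variance is recomputed from the data on every call, it cannot be pushed negative by a server-side optimizer step.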
Now that the 0.18 release is done there's more room to investigate this. @sandracl72 , would it be possible to have a repo that reproduces the error in a minimalistic way with Flower 0.18? If @pedropgusmao doesn't already have one, that is.
Sure, here you have a minimal example using Flower 0.18, with only 50 images: https://github.com/sandracl72/flower_fedopt_debug.git
You have to run it with the --nowandb argument. I've left the wandb code in, in case you want to log some metrics to your own project: with wandb.watch(model, log="all") you can track the weights and gradients, which could be useful for debugging.
You can run run.sh directly.
Thanks !