FedOpt strategies: Central model gives NaN after second aggregation

Open sancarlim opened this issue 4 years ago • 10 comments

Describe the bug

When I use any of the FedOpt strategies (FedAdam, FedYogi, FedAdagrad), training is very unstable: the model outputs NaNs after the second or third aggregation.

Steps/Code to Reproduce

Strategy:

strategy = fl.server.strategy.FedYogi(
    fraction_fit=fc / ac,
    fraction_eval=0.2,  # not used - no federated evaluation
    min_fit_clients=fc,
    min_eval_clients=2,  # not used
    min_available_clients=ac,
    eval_fn=get_eval_fn(model),
    on_fit_config_fn=fit_config,
    on_evaluate_config_fn=evaluate_config,
    initial_parameters=fl.common.weights_to_parameters(model_weights),
)

def get_eval_fn(model):
    """Return an evaluation function for server-side evaluation."""
    _, testset, _ = utils.load_isic_by_patient_server() 
    testloader = DataLoader(testset, batch_size=16, num_workers=4, worker_init_fn=utils.seed_worker, shuffle = False) 

    # The `evaluate` function will be called after every round
    def evaluate(
        weights: fl.common.Weights,
    ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
        # Update model with the latest parameters
        set_parameters(model, weights) 
        loss, auc, accuracy, f1 = utils.val(model, testloader, criterion = nn.BCEWithLogitsLoss()) 

        return float(loss), {"accuracy": float(accuracy), "auc": float(auc)}

    return evaluate
   

def fit_config(rnd: int):
    """Return training configuration dict for each round.
    Keep batch size fixed at 32, perform two rounds of training with one
    local epoch, increase to two local epochs afterwards.
    """
    config = {
        "batch_size": 32,
        "local_epochs": 1 if rnd < 2 else 2,
    }
    return config


def evaluate_config(rnd: int):
    """Return evaluation configuration dict for each round.
    Perform five local evaluation steps on each client (i.e., use five
    batches) during rounds one to three, then increase to ten local
    evaluation steps.
    """
    val_steps = 5 if rnd < 4 else 10
    return {"val_steps": val_steps}
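As a quick check of what these two schedules actually send to the clients, the snippet below evaluates them for the first few rounds. It assumes rounds are numbered from 1, as the `fit progress: (1, ...)` log line further down suggests. Under that assumption only round 1 uses a single local epoch, which is one round fewer than the `fit_config` docstring describes.

```python
def fit_config(rnd: int):
    """Same schedule as above: local epochs depend on the round number."""
    return {"batch_size": 32, "local_epochs": 1 if rnd < 2 else 2}


def evaluate_config(rnd: int):
    """Same schedule as above: validation steps depend on the round number."""
    return {"val_steps": 5 if rnd < 4 else 10}


# Assumption: the server calls these with rnd = 1, 2, 3, ...
for rnd in range(1, 6):
    print(rnd, fit_config(rnd), evaluate_config(rnd))
```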

I have also tried the default eval_fn and the default fit and evaluate configs. The behavior changes, but the run still ends in NaNs.

Dataset: ISIC 2020 https://www.kaggle.com/c/siim-isic-melanoma-classification/data

I have tested with fc=ac=3 and fc=ac=2. In the latter case, one of the clients has a training set of ~2k images and the other ~10.5k.

Model: EfficientNetB2.

Code: https://github.com/sandracl72/flower

  • server_advanced.py (--nowandb)
  • client_isic.py --partition 0 (--nowandb)
  • client_isic.py --partition 1 (--nowandb)

Expected Results

The server aggregates the weights of all clients.

Actual Results

INFO flower 2022-02-11 09:28:34,874 | app.py:109 | Flower server running (10 rounds)
SSL is disabled
INFO flower 2022-02-11 09:28:34,875 | server.py:118 | Initializing global parameters
INFO flower 2022-02-11 09:28:34,875 | server.py:301 | Using initial parameters provided by strategy
INFO flower 2022-02-11 09:28:34,875 | server.py:120 | Evaluating initial parameters
INFO flower 2022-02-11 09:28:54,419 | server.py:123 | initial parameters (loss, other metrics): 0.6916685566558676, {'accuracy': 0.5218716861081655, 'auc': 0.48442367381213036}
INFO flower 2022-02-11 09:28:54,419 | server.py:133 | FL starting
DEBUG flower 2022-02-11 09:28:54,419 | server.py:252 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-02-11 09:30:50,949 | server.py:261 | fit_round received 2 results and 0 failures
INFO flower 2022-02-11 09:31:09,522 | server.py:148 | fit progress: (1, 0.32495464879074293, {'accuracy': 0.901643690349947, 'auc': 0.8418079867528361}, 135.10243083909154)
INFO flower 2022-02-11 09:31:09,522 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-02-11 09:31:09,522 | server.py:252 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-02-11 09:33:03,801 | server.py:261 | fit_round received 2 results and 0 failures
Traceback (most recent call last):
  File "server_advanced.py", line 130, in <module>
    fl.server.start_server("0.0.0.0:8080", config={"num_rounds": rounds}, strategy=strategy)
  File "/workspace/flower/src/py/flwr/server/app.py", line 111, in start_server
    hist = _fl(
  File "/workspace/flower/src/py/flwr/server/app.py", line 148, in _fl
    hist = server.fit(num_rounds=config["num_rounds"])
  File "/workspace/flower/src/py/flwr/server/server.py", line 145, in fit
    res_cen = self.strategy.evaluate(parameters=self.parameters)
  File "/workspace/flower/src/py/flwr/server/strategy/fedavg.py", line 178, in evaluate
    eval_res = self.eval_fn(weights)
  File "server_advanced.py", line 47, in evaluate
    loss, auc, accuracy, f1 = utils.val(model, testloader, criterion = nn.BCEWithLogitsLoss())
  File "/workspace/flower/utils.py", line 468, in val
    val_accuracy = accuracy_score(val_gt2, torch.round(pred2))
  File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/validation.py", line 72, in inner_f
    return f(**kwargs)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/metrics/_classification.py", line 187, in accuracy_score
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/metrics/_classification.py", line 83, in _check_targets
    type_pred = type_of_target(y_pred)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/multiclass.py", line 287, in type_of_target
    _assert_all_finite(y)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/validation.py", line 96, in _assert_all_finite
    raise ValueError(
ValueError: Input contains NaN, infinity or a value too large for dtype('float32').

sancarlim avatar Feb 11 '22 09:02 sancarlim

Hi @sandracl72, thanks for this. I'm guessing you have already tried FedAvg and it works, right? Could you run a few tests for me, please? In your FedYogi, could you print the values of the following:

  • fedavg_weights_aggregate in line 135
  • delta_t in lines 136 and 140
  • m_t in lines 141 and 147
  • v_t in lines 151 and 155
  • new_weights in line 160

Basically, what I need to see are the values of those variables before and after the transformations. Thanks
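For reference, the quantities listed above correspond to the steps of the Yogi server update. The following is a minimal NumPy sketch of that update for a single layer, with a diagnostic print after each step. It is not Flower's actual implementation: the names `yogi_step` and `report`, the zero initialisation of the moments, and the hyperparameter values are assumptions chosen for illustration.

```python
import numpy as np


def report(name, arr):
    """Print NaN/inf counts and the value range of an array."""
    print(
        f"{name}: nan={int(np.isnan(arr).sum())} inf={int(np.isinf(arr).sum())} "
        f"min={arr.min():.4g} max={arr.max():.4g}"
    )


def yogi_step(x_t, x_avg, m_prev, v_prev, eta=1e-2, beta_1=0.9, beta_2=0.99, tau=1e-3):
    """One server-side Yogi update for a single layer (illustrative values)."""
    delta_t = x_avg - x_t  # pseudo-gradient: averaged client weights minus current weights
    m_t = beta_1 * m_prev + (1.0 - beta_1) * delta_t  # first moment
    delta_sq = delta_t ** 2
    v_t = v_prev - (1.0 - beta_2) * delta_sq * np.sign(v_prev - delta_sq)  # second moment
    new_x = x_t + eta * m_t / (np.sqrt(v_t) + tau)  # adaptive server step
    for name, arr in [("delta_t", delta_t), ("m_t", m_t), ("v_t", v_t), ("new_weights", new_x)]:
        report(name, arr)
    return new_x, m_t, v_t


# Toy values standing in for one layer of the global and averaged models.
x_t = np.array([0.5, -0.2, 1.0])
x_avg = np.array([0.4, -0.1, 1.3])
new_x, m_t, v_t = yogi_step(x_t, x_avg, m_prev=np.zeros(3), v_prev=np.zeros(3))
```

Note that the step size depends on `m_t / (sqrt(v_t) + tau)` rather than on the raw difference, so every array in the parameter list is moved by an amount that is largely independent of its own scale.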

pedropgusmao avatar Feb 14 '22 08:02 pedropgusmao

Thank you for your answer @pedropgusmao . Yes, I've performed several experiments using FedAvg and never had this issue before.

Here you can access the output log: https://drive.google.com/file/d/1TWg7EgCHFbDMID7SaW78B1vNnfemWyjZ/view?usp=sharing

Thanks!

sancarlim avatar Feb 14 '22 10:02 sancarlim

Hi @sandracl72, thanks for this. In the log file I don't see a NaN in any of the prints, only in the traceback at File "/workspace/flower/src/py/flwr/server/strategy/fedavg.py", line 178, in evaluate. Could you remove the previous prints and instead print the parameters right at the beginning of this function, please? https://github.com/adap/flower/blob/5fb4f9c1cd0070495049f45f382407e9d95166cd/src/py/flwr/server/strategy/fedavg.py#L173

pedropgusmao avatar Feb 15 '22 10:02 pedropgusmao

Hi @pedropgusmao, I printed sum([np.isnan(w).sum() for w in weights]) before the evaluation to check whether the parameters being evaluated contain NaNs, but apparently they don't. This is the output log:

INFO flower 2022-02-16 08:30:10,677 | app.py:109 | Flower server running (10 rounds)
SSL is disabled
INFO flower 2022-02-16 08:30:10,678 | server.py:118 | Initializing global parameters
INFO flower 2022-02-16 08:30:10,678 | server.py:301 | Using initial parameters provided by strategy
INFO flower 2022-02-16 08:30:10,678 | server.py:120 | Evaluating initial parameters
INFO flower 2022-02-16 08:30:10,678 | fedavg.py:175 | Evaluate model parameters using eval fcn
INFO flower 2022-02-16 08:30:10,797 | fedavg.py:182 | Number of weights with NaN value: 0
INFO flower 2022-02-16 08:30:30,064 | server.py:123 | initial parameters (loss, other metrics): 0.6916685566558676, {'accuracy': 0.5218716861081655, 'auc': 0.48442367381213036}
INFO flower 2022-02-16 08:30:30,064 | server.py:133 | FL starting
DEBUG flower 2022-02-16 08:32:52,575 | server.py:252 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-02-16 08:33:44,737 | server.py:261 | fit_round received 2 results and 0 failures
INFO flower 2022-02-16 08:33:45,329 | fedavg.py:175 | Evaluate model parameters using eval fcn
INFO flower 2022-02-16 08:33:45,438 | fedavg.py:182 | Number of weights with NaN value: 0
INFO flower 2022-02-16 08:34:03,231 | server.py:148 | fit progress: (1, 0.3569583234853719, {'accuracy': 0.8939554612937434, 'auc': 0.7939201167759478}, 213.1666322145611)
INFO flower 2022-02-16 08:34:03,232 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-02-16 08:34:03,233 | server.py:252 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-02-16 08:34:52,459 | server.py:261 | fit_round received 2 results and 0 failures
INFO flower 2022-02-16 08:34:53,040 | fedavg.py:175 | Evaluate model parameters using eval fcn
INFO flower 2022-02-16 08:34:53,155 | fedavg.py:182 | Number of weights with NaN value: 0
Traceback (most recent call last):
  File "server_advanced.py", line 135, in <module>
    fl.server.start_server("0.0.0.0:8080", config={"num_rounds": rounds}, strategy=strategy)
  File "/workspace/flower/src/py/flwr/server/app.py", line 111, in start_server
    hist = _fl(
  File "/workspace/flower/src/py/flwr/server/app.py", line 148, in _fl
    hist = server.fit(num_rounds=config["num_rounds"])
  File "/workspace/flower/src/py/flwr/server/server.py", line 145, in fit
    res_cen = self.strategy.evaluate(parameters=self.parameters)
  File "/workspace/flower/src/py/flwr/server/strategy/fedavg.py", line 183, in evaluate
    eval_res = self.eval_fn(weights)
  File "server_advanced.py", line 51, in evaluate
    loss, auc, accuracy, f1 = utils.val(model, testloader, criterion = nn.BCEWithLogitsLoss())
  File "/workspace/flower/utils.py", line 468, in val
    val_accuracy = accuracy_score(val_gt2, torch.round(pred2))
  File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/validation.py", line 72, in inner_f
    return f(**kwargs)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/metrics/_classification.py", line 187, in accuracy_score
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/metrics/_classification.py", line 83, in _check_targets
    type_pred = type_of_target(y_pred)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/multiclass.py", line 287, in type_of_target
    _assert_all_finite(y)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/validation.py", line 96, in _assert_all_finite
    raise ValueError(
ValueError: Input contains NaN, infinity or a value too large for dtype('float32').

So it seems the NaNs appear in the forward pass after the second aggregation.
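A NaN count alone can miss parameters that are finite but no longer valid for the layer that consumes them. The sketch below is one way to widen the check: it also reports infinities, the dtype, and the value range of each array. The helper name `summarize` and the toy arrays are hypothetical and only illustrate the idea.

```python
import numpy as np


def summarize(weights):
    """Return per-array diagnostics that go beyond a NaN count."""
    rows = []
    for i, w in enumerate(weights):
        w = np.asarray(w)
        as_float = w.astype(np.float64)  # lets isnan/isinf work on integer arrays too
        rows.append({
            "index": i,
            "dtype": str(w.dtype),
            "nan": int(np.isnan(as_float).sum()),
            "inf": int(np.isinf(as_float).sum()),
            "min": float(as_float.min()),
            "max": float(as_float.max()),
        })
    return rows


# Toy stand-ins for entries of a model's parameter list.
weights = [
    np.array([0.1, -0.3], dtype=np.float32),   # ordinary weights
    np.array([1.0, -0.02], dtype=np.float32),  # a variance-like buffer that went negative
    np.array(7, dtype=np.int64),               # an integer counter buffer
]
for row in summarize(weights):
    print(row)
```

Here every array has zero NaNs, yet the second one has a negative minimum, which would only become visible in the forward pass.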

Here you can see the histograms of the weights and gradients of the aggregated model: https://wandb.ai/eyeforai/dai-healthcare/reports/FedOpt-Debug--VmlldzoxNTY4ODQ1?accessToken=4kqdpvadvojfpd8my9iflqsw0zk7d9d0xativ7eu5ad69t17ovi7fm91v2co0oy4

sancarlim avatar Feb 16 '22 07:02 sancarlim

Hi @sandracl72, I might be wrong, but could it be that the evaluation is doing something weird? From the log I see INFO flower 2022-02-16 08:34:53,155 | fedavg.py:182 | Number of weights with NaN value: 0 right before evaluation actually starts. Could you also use the same isnan check to print inside https://github.com/adap/flower/blob/d80c8c2738b79badbc820f940efcd7fb4fff9503/src/py/flwr/server/strategy/fedyogi.py#L161, please? Thanks

pedropgusmao avatar Feb 16 '22 08:02 pedropgusmao

@danieljanes this might be related to BatchNorm. Do you remember if we had any example with aggregation and batch norm?

pedropgusmao avatar Feb 16 '22 10:02 pedropgusmao

@danieljanes the weights aren't NaN before or after the aggregation, but after the forward pass (in the second evaluation round) the results are NaN. We thought it might be related to the BN, and now I have found that if I don't aggregate the "bn" weights, it works fine.

Changes:

def set_parameters(self, parameters: List[np.ndarray]) -> None:
    # Set model parameters from a list of NumPy ndarrays,
    # skipping every state_dict entry whose name contains 'bn'
    keys = [k for k in self.model.state_dict().keys() if 'bn' not in k]
    params_dict = zip(keys, parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    # strict=False because the 'bn' entries are intentionally left out
    self.model.load_state_dict(state_dict, strict=False)

def get_parameters(self) -> List[np.ndarray]:
    # Return every state_dict entry except those whose name contains 'bn'
    return [val.cpu().numpy() for name, val in self.model.state_dict().items() if 'bn' not in name]
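One mechanism that would be consistent with these symptoms, though it is not confirmed in this thread, is that a batch-norm running variance becomes negative after the server update. Such a value is finite, so an isnan check on the weights passes, but inference-mode normalisation takes its square root. The NumPy sketch below illustrates this under that assumption; `batchnorm_eval` is a hand-written stand-in for the standard inference formula, not a library call.

```python
import numpy as np


def batchnorm_eval(x, running_mean, running_var, gamma, beta, eps=1e-5):
    """Inference-mode batch normalisation for inputs of shape (batch, channels)."""
    return gamma * (x - running_mean) / np.sqrt(running_var + eps) + beta


x = np.array([[0.5, 1.0], [1.5, -1.0]])
running_mean = np.array([1.0, 0.0])
gamma, beta = np.ones(2), np.zeros(2)

# Valid statistics: both running variances are positive.
healthy = batchnorm_eval(x, running_mean, np.array([0.25, 1.0]), gamma, beta)

# Hypothetical post-aggregation statistics: the second variance is negative but finite.
bad_var = np.array([0.25, -0.1])
with np.errstate(invalid="ignore"):  # silence the sqrt-of-negative warning
    broken = batchnorm_eval(x, running_mean, bad_var, gamma, beta)

print(np.isnan(bad_var).any())  # False: the buffer itself contains no NaN
print(np.isnan(healthy).any())  # False
print(np.isnan(broken).any())   # True: NaN appears only in the forward pass
```

Excluding the 'bn' entries from aggregation, as in the change above, keeps these buffers out of the server update entirely.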

sancarlim avatar Feb 17 '22 08:02 sancarlim

This is being investigated, but for now, I'd recommend:

  • Swapping BatchNorm layers for GroupNorm layers.
  • Making sure grpcio==1.43 is pinned in poetry.lock. If not, run poetry add grpcio==1.43. This is related to https://github.com/ray-project/ray/issues/22518

pedropgusmao avatar Feb 27 '22 10:02 pedropgusmao

Now that the 0.18 release is done there's more room to investigate this. @sandracl72 , would it be possible to have a repo that reproduces the error in a minimalistic way with Flower 0.18? If @pedropgusmao doesn't already have one, that is.

danieljanes avatar Mar 01 '22 10:03 danieljanes

Sure, here you have a minimal example using Flower 0.18, with only 50 images: https://github.com/sandracl72/flower_fedopt_debug.git

You have to run it with the --nowandb argument. I've left the wandb code in, in case you want to log some metrics to your own project. With wandb.watch(model, log="all") you can track the weights and gradients, which could be useful for debugging. You can also run run.sh directly.

Thanks !

sancarlim avatar Mar 02 '22 08:03 sancarlim