
Problem with the function "_local_test_on_all_clients" in "https://github.com/FedML-AI/FedML/blob/master/python/fedml/simulation/sp/fedavg/fedavg_api.py"

shubham22124 opened this issue 2 years ago • 3 comments

```python
# Excerpt from fedavg_api.py (a method of the FedAvg simulation API class).
# Module-level imports the method relies on:
import copy
import logging

import wandb
from fedml import mlops


def _local_test_on_all_clients(self, round_idx):
    logging.info("################local_test_on_all_clients : {}".format(round_idx))

    train_metrics = {"num_samples": [], "num_correct": [], "losses": []}
    test_metrics = {"num_samples": [], "num_correct": [], "losses": []}

    client = self.client_list[0]  # <- the line this issue asks about

    for client_idx in range(self.args.client_num_in_total):
        """
        Note: for datasets like "fed_CIFAR100" and "fed_shakespeare",
        the training client number is larger than the testing client number
        """
        if self.test_data_local_dict[client_idx] is None:
            continue
        client.update_local_dataset(
            0,
            self.train_data_local_dict[client_idx],
            self.test_data_local_dict[client_idx],
            self.train_data_local_num_dict[client_idx],
        )
        # train data
        train_local_metrics = client.local_test(False)
        train_metrics["num_samples"].append(copy.deepcopy(train_local_metrics["test_total"]))
        train_metrics["num_correct"].append(copy.deepcopy(train_local_metrics["test_correct"]))
        train_metrics["losses"].append(copy.deepcopy(train_local_metrics["test_loss"]))

        # test data
        test_local_metrics = client.local_test(True)
        test_metrics["num_samples"].append(copy.deepcopy(test_local_metrics["test_total"]))
        test_metrics["num_correct"].append(copy.deepcopy(test_local_metrics["test_correct"]))
        test_metrics["losses"].append(copy.deepcopy(test_local_metrics["test_loss"]))

    # test on training dataset
    train_acc = sum(train_metrics["num_correct"]) / sum(train_metrics["num_samples"])
    train_loss = sum(train_metrics["losses"]) / sum(train_metrics["num_samples"])

    # test on test dataset
    test_acc = sum(test_metrics["num_correct"]) / sum(test_metrics["num_samples"])
    test_loss = sum(test_metrics["losses"]) / sum(test_metrics["num_samples"])

    stats = {"training_acc": train_acc, "training_loss": train_loss}
    if self.args.enable_wandb:
        wandb.log({"Train/Acc": train_acc, "round": round_idx})
        wandb.log({"Train/Loss": train_loss, "round": round_idx})

    mlops.log({"Train/Acc": train_acc, "round": round_idx})
    mlops.log({"Train/Loss": train_loss, "round": round_idx})
    logging.info(stats)

    stats = {"test_acc": test_acc, "test_loss": test_loss}
    if self.args.enable_wandb:
        wandb.log({"Test/Acc": test_acc, "round": round_idx})
        wandb.log({"Test/Loss": test_loss, "round": round_idx})

    mlops.log({"Test/Acc": test_acc, "round": round_idx})
    mlops.log({"Test/Loss": test_loss, "round": round_idx})
    logging.info(stats)
```

In the 4th line of the function (`client = self.client_list[0]`), why is the zeroth client always selected? This way, testing happens only on the model belonging to the zeroth client, but don't we want the average test error of each client on its own local dataset?

shubham22124 · Nov 07 '23 15:11
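On the averaging question: `train_acc` and `test_acc` in the function divide the summed correct counts by the summed sample counts, which is a sample-weighted average over clients rather than an unweighted mean of each client's accuracy. A minimal sketch of the two options follows; the helper names are hypothetical and the numbers are toy values made up purely for illustration.

```python
from typing import List


def pooled_accuracy(num_correct: List[int], num_samples: List[int]) -> float:
    """Sample-weighted accuracy, pooled the same way as train_acc/test_acc."""
    return sum(num_correct) / sum(num_samples)


def mean_client_accuracy(num_correct: List[int], num_samples: List[int]) -> float:
    """Unweighted mean of per-client accuracies (every client counts equally)."""
    per_client = [c / n for c, n in zip(num_correct, num_samples)]
    return sum(per_client) / len(per_client)


if __name__ == "__main__":
    # Toy values: client 0 has 900 samples, client 1 has 100.
    correct = [810, 50]
    samples = [900, 100]
    print(pooled_accuracy(correct, samples))       # (810 + 50) / 1000 = 0.86
    print(mean_client_accuracy(correct, samples))  # (0.9 + 0.5) / 2 = 0.7
```

With one large and one small client the two numbers diverge, because the pooled form lets clients with more samples count for more. The losses are pooled the same way, which yields a per-sample loss only if each client's reported `test_loss` is a sum over its samples (an assumption here, not something the thread states).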

@shubham22124 Thank you for the question. In that line, we only take the general client state (e.g., the model) from the first client; the evaluation itself still runs across all clients (see `client_idx`), as shown in line 195: https://github.com/FedML-AI/FedML/blob/master/python/fedml/simulation/sp/fedavg/fedavg_api.py#L195

fedml-dimitris · Nov 08 '23 22:11

But line 195 only updates the dataset. Shouldn't the model be updated as well, since each client trains locally and ends up with a model different from the one it received from the server?

shubham22124 · Nov 09 '23 04:11
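The follow-up describes a different metric: testing each client's own locally trained model on its own data (sometimes called personalized evaluation). Under that reading, both the model and the dataset would have to change per client. A minimal sketch, assuming a toy one-parameter threshold classifier in place of a real model; `evaluate`, `personalized_accuracy`, and the per-client thresholds are hypothetical and not part of FedML.

```python
from typing import Dict, List, Tuple

Dataset = List[Tuple[float, int]]  # toy (feature, label) pairs


def evaluate(threshold: float, data: Dataset) -> Tuple[int, int]:
    """Return (num_correct, num_samples) for a toy classifier: predict 1 if x >= threshold."""
    correct = sum(1 for x, y in data if int(x >= threshold) == y)
    return correct, len(data)


def personalized_accuracy(local_thresholds: Dict[int, float],
                          local_data: Dict[int, Dataset]) -> float:
    """Evaluate each client's OWN model on its OWN data, then pool the counts."""
    total_correct = total_samples = 0
    for client_idx, data in local_data.items():
        # Both the model (threshold) and the dataset change per client.
        correct, n = evaluate(local_thresholds[client_idx], data)
        total_correct += correct
        total_samples += n
    return total_correct / total_samples


if __name__ == "__main__":
    data = {0: [(0.2, 0), (0.9, 1), (0.6, 0)], 1: [(0.7, 1), (0.1, 0)]}
    # Each client's own (toy) locally trained model:
    print(personalized_accuracy({0: 0.65, 1: 0.5}, data))  # 1.0
    # Same parameters for every client, i.e. one global model:
    print(personalized_accuracy({0: 0.5, 1: 0.5}, data))   # 0.8
```

Passing the same parameters for every client reduces this to evaluating a single global model, which is the case the maintainer's reply addresses.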

@shubham22124 Basically, lines 193-198 are where the global model is evaluated against each client's local dataset. Every client's model is the same (the global model), hence `client = self.client_list[0]`. In other words, the evaluation of the global model is rotated across each client's dataset.

fedml-dimitris · Nov 09 '23 19:11
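The reply's point can be sketched as follows, assuming (as the reply implies) that every client object wraps one shared model that holds the aggregated global weights when this function runs. `SharedModel`, `Client`, `update_local_dataset`, `local_test`, and `global_accuracy` below are simplified hypothetical stand-ins with made-up signatures, not FedML's actual classes.

```python
from typing import Dict, List, Tuple

Dataset = List[Tuple[float, int]]  # toy (feature, label) pairs


class SharedModel:
    """Hypothetical stand-in for the single model every client object points to."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold  # plays the role of the aggregated global weights


class Client:
    """Hypothetical client wrapper: a reference to the shared model plus a swappable dataset."""

    def __init__(self, model: SharedModel) -> None:
        self.model = model
        self.test_data: Dataset = []

    def update_local_dataset(self, test_data: Dataset) -> None:
        self.test_data = test_data  # only the data changes; the model reference stays the same

    def local_test(self) -> Tuple[int, int]:
        correct = sum(1 for x, y in self.test_data if int(x >= self.model.threshold) == y)
        return correct, len(self.test_data)


def global_accuracy(client: Client, local_data: Dict[int, Dataset]) -> float:
    """Rotate ONE client object over every client's dataset and pool the counts."""
    total_correct = total_samples = 0
    for data in local_data.values():
        client.update_local_dataset(data)
        correct, n = client.local_test()
        total_correct += correct
        total_samples += n
    return total_correct / total_samples


if __name__ == "__main__":
    shared = SharedModel(threshold=0.5)
    clients = [Client(shared), Client(shared)]
    data = {0: [(0.2, 0), (0.9, 1), (0.6, 0)], 1: [(0.7, 1), (0.1, 0)]}
    # Any client object gives the same result because all of them share one model.
    print(global_accuracy(clients[0], data), global_accuracy(clients[1], data))  # 0.8 0.8
```

Because only the dataset is swapped, picking `client_list[0]` is a convenience: any client object would give the same numbers, as long as the shared model really holds the global weights at that point.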