Problem with the function `_local_test_on_all_clients` in https://github.com/FedML-AI/FedML/blob/master/python/fedml/simulation/sp/fedavg/fedavg_api.py
```python
def _local_test_on_all_clients(self, round_idx):
    logging.info("################local_test_on_all_clients : {}".format(round_idx))

    train_metrics = {"num_samples": [], "num_correct": [], "losses": []}
    test_metrics = {"num_samples": [], "num_correct": [], "losses": []}

    client = self.client_list[0]  # <-- the line in question

    for client_idx in range(self.args.client_num_in_total):
        """
        Note: for datasets like "fed_CIFAR100" and "fed_shakespeare",
        the training client number is larger than the testing client number
        """
        if self.test_data_local_dict[client_idx] is None:
            continue
        client.update_local_dataset(
            0,
            self.train_data_local_dict[client_idx],
            self.test_data_local_dict[client_idx],
            self.train_data_local_num_dict[client_idx],
        )

        # train data
        train_local_metrics = client.local_test(False)
        train_metrics["num_samples"].append(copy.deepcopy(train_local_metrics["test_total"]))
        train_metrics["num_correct"].append(copy.deepcopy(train_local_metrics["test_correct"]))
        train_metrics["losses"].append(copy.deepcopy(train_local_metrics["test_loss"]))

        # test data
        test_local_metrics = client.local_test(True)
        test_metrics["num_samples"].append(copy.deepcopy(test_local_metrics["test_total"]))
        test_metrics["num_correct"].append(copy.deepcopy(test_local_metrics["test_correct"]))
        test_metrics["losses"].append(copy.deepcopy(test_local_metrics["test_loss"]))

    # test on training dataset
    train_acc = sum(train_metrics["num_correct"]) / sum(train_metrics["num_samples"])
    train_loss = sum(train_metrics["losses"]) / sum(train_metrics["num_samples"])

    # test on test dataset
    test_acc = sum(test_metrics["num_correct"]) / sum(test_metrics["num_samples"])
    test_loss = sum(test_metrics["losses"]) / sum(test_metrics["num_samples"])

    stats = {"training_acc": train_acc, "training_loss": train_loss}
    if self.args.enable_wandb:
        wandb.log({"Train/Acc": train_acc, "round": round_idx})
        wandb.log({"Train/Loss": train_loss, "round": round_idx})
    mlops.log({"Train/Acc": train_acc, "round": round_idx})
    mlops.log({"Train/Loss": train_loss, "round": round_idx})
    logging.info(stats)

    stats = {"test_acc": test_acc, "test_loss": test_loss}
    if self.args.enable_wandb:
        wandb.log({"Test/Acc": test_acc, "round": round_idx})
        wandb.log({"Test/Loss": test_loss, "round": round_idx})
    mlops.log({"Test/Acc": test_acc, "round": round_idx})
    mlops.log({"Test/Loss": test_loss, "round": round_idx})
    logging.info(stats)
```
In the 4th line of the function, why is the zeroth client always selected? This way, testing happens only on the model corresponding to the zeroth client, but don't we want the average test error over each client's local dataset?
@shubham22124 Thank you for asking this question. In that line, we just take a generic client object (and its state, e.g., the model) from the first entry of the list. The evaluation still happens across all clients (note the `client_idx` loop), as shown in line 195: https://github.com/FedML-AI/FedML/blob/master/python/fedml/simulation/sp/fedavg/fedavg_api.py#L195
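To make the pattern concrete, here is a minimal, self-contained sketch (the names `SimpleClient` and `evaluate_global_model` are hypothetical stand-ins, not FedML APIs): one client object is reused as a generic evaluator, and only its dataset is swapped on each iteration.

```python
# Hypothetical sketch of the "reuse one client object" pattern; SimpleClient
# and evaluate_global_model are made-up names, not FedML's actual API.

class SimpleClient:
    def __init__(self, model):
        self.model = model            # the shared (global) model
        self.local_test_data = None

    def update_local_dataset(self, client_idx, test_data):
        # Only the dataset is swapped; the model stays untouched.
        self.local_test_data = test_data

    def local_test(self):
        correct = sum(1 for x, y in self.local_test_data if self.model(x) == y)
        return {"test_correct": correct, "test_total": len(self.local_test_data)}


def evaluate_global_model(global_model, test_data_per_client):
    client = SimpleClient(global_model)          # analogous to client_list[0]
    per_client_metrics = []
    for idx, data in enumerate(test_data_per_client):
        if data is None:                         # some datasets have fewer test clients
            continue
        client.update_local_dataset(idx, data)   # rotate the dataset, reuse the object
        per_client_metrics.append(client.local_test())
    return per_client_metrics


model = lambda x: int(x >= 0)                    # toy "global model"
data = [[(1, 1), (-2, 0)], None, [(3, 1), (-1, 1)]]
print(evaluate_global_model(model, data))
# [{'test_correct': 2, 'test_total': 2}, {'test_correct': 1, 'test_total': 2}]
```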
But line 195 just updates the dataset. Shouldn't the model be updated as well, since each client undergoes local training and ends up with a different model from the one received from the server?
@shubham22124 So basically, lines 193-198 are where the global model is evaluated against the local dataset of each client. At that point every client holds the same (global) model, hence the `client = self.client_list[0]`.
In other words, the evaluation of the global model is rotated across each client's dataset.
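A closing note on the metric aggregation in the snippet above: the `sum(...) / sum(...)` pattern computes a sample-weighted average across clients, not a simple mean of per-client accuracies, so clients with more data contribute more to the reported numbers. A small worked example with made-up counts:

```python
# Hypothetical per-client results: (num_correct, num_samples)
results = [(90, 100), (40, 50), (5, 10)]

num_correct = [c for c, _ in results]
num_samples = [n for _, n in results]

# Sample-weighted accuracy, matching the sum(...)/sum(...) pattern in the function
weighted_acc = sum(num_correct) / sum(num_samples)        # 135 / 160 = 0.84375

# Simple (unweighted) mean of per-client accuracies, for contrast
mean_acc = sum(c / n for c, n in results) / len(results)  # (0.9 + 0.8 + 0.5) / 3 ≈ 0.733
print(weighted_acc, mean_acc)
```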