PFLlib
In the pathological non-IID setting, the sample distribution across clients may be unbalanced even when `balance` is True.
First, thanks for the code, which has helped me a lot. While reading the function `separate_data` in the file `./dataset/utils/dataset_utils.py`, I noticed that `balance` is used in lines 67-68. The README.md does not claim that the code provides a balanced pathological non-IID setting, but the behaviour is ambiguous, so I am raising the issue here. Below I explain how the sample distribution across clients is affected by the class distribution of the initial dataset. As a result, the sample distribution across clients may be unbalanced even when `balance` (the argument of the function) is True.
The question concerns the relationship between the initial dataset and the sample distribution across clients, so all other variables are held fixed. Line 62 is `selected_clients = selected_clients[:int(num_clients/num_classes*class_per_client)]` and line 66 is `num_per = num_all_samples / num_selected_clients`. From these two lines, `num_per` depends only on `num_all_samples`, because the other variables are fixed by assumption. Line 68 is `num_samples = [int(num_per) for _ in range(num_selected_clients-1)]`, so `num_samples` also depends on `num_all_samples`. Line 64 sets `num_all_samples=len(idx_for_each_class[i])`, which is the size of class `i` in the initial dataset. Each client pays the same cost for a class (one of its `class_per_client` slots), but the number of samples it receives for that cost depends on the size of the class. This is why the sample distribution across clients follows the class distribution of the initial dataset.
One simple example to verify my understanding:
- labels [number of samples]: 0 [20], 1 [20], 2 [40], 3 [20]
- num_clients: 10
- class_per_client: 2

Running the code with `balance==True` gives the following results, written as client [labels | number of samples]:

0 [0,1 | 8]; 1 [0,1 | 8]; 2 [0,1 | 8]; 3 [0,1 | 8]; 4 [0,1 | 8];
5 [2,3 | 12]; 6 [2,3 | 12]; 7 [2,3 | 12]; 8 [2,3 | 12]; 9 [2,3 | 12].
This is clearly unbalanced. I also ran the code on the MNIST dataset with num_clients: 10 and class_per_client: 2. It produced `[6332, 6333, 6044, 6045, 5631, 5632, 6091, 6092, 5899, 5901]`. This result is also unbalanced, though less obviously, because MNIST's initial class distribution is fairly even ([label | number of samples of this label]: 0 5923, 1 6742, 2 5958, 3 6131, 4 5842, 5 5421, 6 5918, 7 6265, 8 5851, 9 5949).
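The splitting logic described above can be replayed in a few lines to check both sets of numbers. This is a simplified sketch, not the PFLlib source. It keeps only the quoted lines 62-68 and makes one labelled assumption: each class is offered, in client-id order, to the clients that still have an unused class slot, and the last selected client takes the remainder. The function name `pathological_split_sizes` is hypothetical. Under these assumptions the sketch reproduces both the toy result and the MNIST counts quoted above.

```python
def pathological_split_sizes(class_sizes, num_clients, class_per_client):
    """Per-client sample counts for a pathological non-IID split with balance=True.

    Simplified sketch of the loop discussed above (hypothetical helper, not PFLlib code).
    ASSUMPTION: a class goes to the first clients (by id) that still have a free class
    slot, and the last selected client receives the remainder of that class.
    """
    num_classes = len(class_sizes)
    slots_left = [class_per_client] * num_clients  # classes each client may still receive
    totals = [0] * num_clients
    for num_all_samples in class_sizes:
        selected_clients = [c for c in range(num_clients) if slots_left[c] > 0]
        selected_clients = selected_clients[:int(num_clients / num_classes * class_per_client)]
        num_selected_clients = len(selected_clients)
        num_per = num_all_samples / num_selected_clients  # depends on the class size
        num_samples = [int(num_per) for _ in range(num_selected_clients - 1)]
        num_samples.append(num_all_samples - sum(num_samples))  # remainder to the last client
        for client, n in zip(selected_clients, num_samples):
            totals[client] += n
            slots_left[client] -= 1
    return totals


print(pathological_split_sizes([20, 20, 40, 20], num_clients=10, class_per_client=2))
# [8, 8, 8, 8, 8, 12, 12, 12, 12, 12]
mnist_counts = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
print(pathological_split_sizes(mnist_counts, num_clients=10, class_per_client=2))
# [6332, 6333, 6044, 6045, 5631, 5632, 6091, 6092, 5899, 5901]
```

Every client spends the same `class_per_client` slots. Its total still tracks the sizes of the classes it happened to receive.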
Finally, I think the method of partitioning the samples into shards in the paper Communication-Efficient Learning of Deep Networks from Decentralized Data may be inspiring. Thanks for your time.
Hi, @bird-two. Thank you for your valuable comments! The problem you mentioned does exist when the pathological non-IID setting is generated from an unbalanced raw dataset.
Statistical heterogeneity in FL is caused by both non-IID and unbalanced data, so I only provide a balanced data distribution for the IID setting.
PRs for balanced distributions in non-IID settings are welcome :-)
> I think the method of partitioning the samples into shards in the paper Communication-Efficient Learning of Deep Networks from Decentralized Data may be inspiring
In Communication-Efficient Learning of Deep Networks from Decentralized Data, the authors say: "...Non-IID, where we first sort the data by digit label, divide it into 200 shards of size 300, and assign each of 100 clients 2 shards. This is a pathological non-IID partition of the data, as most clients will only have examples of two digits." Suppose the number of images with label 0 is 5923, which is less than 6000 (20 shards × 300). Then the last shard that starts in label 0 spills over into label 1, so some clients may have samples from more than two labels.
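The sort-and-shard scheme quoted above can be sketched as follows. The sketch uses the MNIST class counts from this thread to build a synthetic, already-sorted label array as a stand-in for the real labels. The function name and its `seed` parameter are hypothetical. None of the cumulative class boundaries (5923, 12665, ...) is a multiple of 300, so each boundary falls inside a shard. As a result, 9 of the 200 shards contain two labels. Every client gets exactly 600 samples, but a client that draws one of these mixed shards can end up with more than two labels.

```python
import numpy as np


def fedavg_shard_partition(labels, num_clients=100, shards_per_client=2, shard_size=300, seed=0):
    """Sort-by-label shard split in the style described in the FedAvg paper (sketch).

    Assumes len(labels) == num_clients * shards_per_client * shard_size.
    """
    rng = np.random.default_rng(seed)
    order = np.argsort(labels, kind="stable")  # sample indices sorted by label
    shards = order.reshape(-1, shard_size)  # row s holds the indices of shard s
    shard_ids = rng.permutation(len(shards))  # random shard-to-client assignment
    return {
        c: shards[shard_ids[c * shards_per_client:(c + 1) * shards_per_client]].ravel()
        for c in range(num_clients)
    }


mnist_counts = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
labels = np.repeat(np.arange(10), mnist_counts)  # synthetic, already sorted by label
parts = fedavg_shard_partition(labels)

# labels is sorted, so consecutive blocks of 300 labels are exactly the shards
mixed_shards = sum(len(np.unique(row)) > 1 for row in labels.reshape(-1, 300))
over_two = sum(len(np.unique(labels[idx])) > 2 for idx in parts.values())
print(mixed_shards)  # 9: every class boundary falls inside a shard
print(over_two)  # clients with more than two labels (depends on the random assignment)
```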
> If the number of images that belong to label 0 is 5923, which is less than 6000, some clients may have samples of more than two labels.
I think this is true but unavoidable. The authors of Communication-Efficient... say that "most clients will only have examples of two digits", not all clients. I also checked the code at https://github.com/AshwinRJ/Federated-Learning-PyTorch, which has the most stars among the implementations of the paper (I found it via https://paperswithcode.com/). Its function `mnist_noniid` in `./src/sampling.py` cannot satisfy both conditions at the same time either: balanced data, and only two labels per client in their setting.
It may be acceptable for most clients, but not all of them, to be limited to `class_per_client` labels. If so, I have written some code based on your `./dataset/utils/dataset_utils.py` and their `./src/sampling.py`, which can be plugged into your code directly. I renamed some variables from the original code, as listed below:
- `n_nets`: the number of clients
- `net_dataidx_map`: the `dataidx_map`

```python
import numpy as np

# Uses variables from the enclosing `separate_data` context (assumed):
# n_train (total number of samples), idxs_each_class (one index array per class),
# num_classes, n_nets (number of clients) and class_per_client.

# decide the number of samples in one shard
n_sample_per_shard = n_train // (n_nets * class_per_client)
# compute the number of full shards per class
n_shard_per_class_list = np.array([0 for _ in range(num_classes)])
for i in range(num_classes):
    n_shard_per_class = len(idxs_each_class[i]) // n_sample_per_shard
    n_shard_per_class_list[i] = n_shard_per_class
bidx_per_class_list = [0 for _ in range(num_classes)]  # beginning index of the remaining samples per class
net_dataidx_map = {i: np.array([], dtype=int) for i in range(n_nets)}
net_i = n_nets  # first client handled by the fallback; n_nets means no fallback is needed
for i in range(n_nets):
    # classes that still have at least one unassigned full shard
    class_with_surplus_shards = np.where(n_shard_per_class_list > 0)[0]
    if len(class_with_surplus_shards) >= class_per_client:
        # give client i one shard from each of class_per_client distinct classes
        selected_classes = np.random.choice(class_with_surplus_shards, class_per_client, replace=False)
        for j in selected_classes:
            net_dataidx_map[i] = np.concatenate(
                (net_dataidx_map[i],
                 idxs_each_class[j][bidx_per_class_list[j]:bidx_per_class_list[j] + n_sample_per_shard]),
                axis=0)
            n_shard_per_class_list[j] -= 1
            bidx_per_class_list[j] += n_sample_per_shard
    else:
        # too few classes with shards left: the remaining clients go to the fallback below
        net_i = i
        break
# Fallback modified from the function `mnist_noniid` in sampling.py: pool all leftover
# samples, cut them into shards and give class_per_client random shards to each remaining
# client. A pooled shard may span several classes, so these clients can see more labels.
# (The original checked `net_i > 0`, which skipped the fallback if the loop broke at i == 0.)
if net_i < n_nets:
    remaining_idxs = np.concatenate([idxs_each_class[j][bidx_per_class_list[j]:] for j in range(num_classes)], axis=0)
    n_shard_remaining_idxs = len(remaining_idxs) // n_sample_per_shard
    distribute_order = np.random.permutation(n_shard_remaining_idxs)
    for i in range(net_i, n_nets):
        t = class_per_client * (i - net_i)
        net_dataidx_map[i] = np.concatenate(
            [remaining_idxs[distribute_order[t + j] * n_sample_per_shard:
                            (distribute_order[t + j] + 1) * n_sample_per_shard]
             for j in range(class_per_client)],
            axis=0)
```
The resulting partitions are visualized as bubble charts (bubble size indicates the number of samples) for the following settings:
- clients 10, class per client 1;
- clients 50, class per client 1;
- clients 10, class per client 2;
- clients 50, class per client 2.
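The data behind such a bubble chart is a client-by-label count matrix. The sketch below computes it from a `net_dataidx_map`/`dataidx_map`-style dictionary. It assumes client ids 0..n-1, integer index arrays and a NumPy array of labels. The helper name is hypothetical.

```python
import numpy as np


def client_label_counts(labels, dataidx_map, num_classes):
    """Return a (num_clients, num_classes) matrix of per-client label counts.

    Each nonzero cell (client, label) corresponds to one bubble whose size is the count.
    Assumes client ids are 0..num_clients-1 and index arrays have an integer dtype.
    """
    counts = np.zeros((len(dataidx_map), num_classes), dtype=int)
    for client, idxs in dataidx_map.items():
        counts[client] = np.bincount(labels[idxs], minlength=num_classes)
    return counts


labels = np.array([0, 0, 1, 1, 2, 2, 3, 3])
toy_map = {0: np.array([0, 1, 2]), 1: np.array([3, 4, 5, 6, 7])}
counts = client_label_counts(labels, toy_map, num_classes=4)
print(counts)
# [[2 1 0 0]
#  [0 1 2 2]]
print(counts.sum(axis=1))  # [3 5], the number of samples per client
```

Each nonzero cell is one bubble. To draw the chart, pass the cells to a scatter plot with marker size proportional to the count. The plotting step is left out here.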




I also found that most current implementations of the Dirichlet partition method cannot generate balanced client datasets, which may make training harder.
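For comparison, here is a minimal sketch of a typical Dirichlet label-skew split. For each class, client proportions are drawn from Dir(alpha), and the shuffled indices of that class are cut accordingly. This is an assumed generic version, not the code of PFLlib or any particular repository. The function name and the `seed` parameter are hypothetical. The MNIST counts from this thread are used again to build synthetic labels. Nothing in the procedure equalizes client totals, which is why the resulting client datasets are generally unbalanced.

```python
import numpy as np


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Return {client: list of sample indices} for a Dirichlet label-skew split (sketch)."""
    rng = np.random.default_rng(seed)
    dataidx_map = {c: [] for c in range(num_clients)}
    for k in np.unique(labels):
        idxs = rng.permutation(np.where(labels == k)[0])  # shuffled indices of class k
        proportions = rng.dirichlet(alpha * np.ones(num_clients))  # share of class k per client
        cuts = (np.cumsum(proportions) * len(idxs)).astype(int)[:-1]  # split points
        for c, part in enumerate(np.split(idxs, cuts)):
            dataidx_map[c].extend(part.tolist())
    return dataidx_map


mnist_counts = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
labels = np.repeat(np.arange(10), mnist_counts)  # synthetic stand-in for MNIST labels
sizes = [len(v) for v in dirichlet_partition(labels, num_clients=10, alpha=0.5).values()]
print(sizes, sum(sizes))  # client totals vary; together they still cover all 60000 samples
```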