# Federated-Learning-in-PyTorch
Handy PyTorch implementation of Federated Learning (for your painless research)
- NOTE: This repository will be updated to ver. 2.0 in August 2022.
## Federated Averaging (FedAvg) in PyTorch
An unofficial PyTorch implementation of the Federated Averaging (FedAvg) algorithm proposed in the paper *Communication-Efficient Learning of Deep Networks from Decentralized Data*. (Implemented with Python 3.9.2.)
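At its core, FedAvg repeats three steps each round: the server broadcasts the global model to a sampled fraction of clients, each client runs local SGD on its own data, and the server aggregates the returned weights by a data-size-weighted average. Below is a minimal sketch of the aggregation step; it is illustrative, not this repo's exact code (the function name `average_models` and plain `state_dict` averaging are assumptions).

```python
import copy

def average_models(global_model, client_states, client_sizes):
    """FedAvg aggregation: data-size-weighted average of client state_dicts."""
    total = float(sum(client_sizes))
    avg_state = copy.deepcopy(client_states[0])
    for key in avg_state:
        # weight each client's parameters by its share of the total data
        avg_state[key] = sum(
            state[key].float() * (n / total)
            for state, n in zip(client_states, client_sizes)
        )
    global_model.load_state_dict(avg_state)
    return global_model
```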
## Implementation points
- Exactly implement the models (the '2NN' and 'CNN' described in the paper) so that they have the same number of parameters reported in the paper (see the model sketch below).
  - 2NN: `TwoNN` class in `models.py`; 199,210 parameters
  - CNN: `CNN` class in `models.py`; 1,663,370 parameters
- Exactly implement the non-IID data split (see the split sketch below).
  - Each client has samples of at least two digits when using the `MNIST` dataset.
- Implement multiprocessing for client update and client evaluation (see the parallel-update sketch below).
- Support TensorBoard for log tracking (see the logging sketch below).
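As a cross-check of the parameter counts above, here is a minimal 2NN in the paper's configuration (a 784-200-200-10 MLP with ReLU); it need not match the `TwoNN` class in `models.py` line for line.

```python
import torch.nn as nn

class TwoNN(nn.Module):
    """MLP from the FedAvg paper: 784 -> 200 -> 200 -> 10 with ReLU."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 200), nn.ReLU(),
            nn.Linear(200, 200), nn.ReLU(),
            nn.Linear(200, 10),
        )

    def forward(self, x):
        return self.net(x)

# (784*200 + 200) + (200*200 + 200) + (200*10 + 10) = 199,210
print(sum(p.numel() for p in TwoNN().parameters()))  # 199210
```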
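The non-IID split in the paper works by sorting the training set by label, cutting it into 200 shards of 300 images, and dealing two label-sorted shards to each of 100 clients. A rough sketch follows; the repo's actual shard bookkeeping may differ.

```python
import numpy as np
from torchvision import datasets

def noniid_split(labels, num_clients=100, shards_per_client=2):
    """Sort indices by label, cut into equal shards, deal shards to clients."""
    num_shards = num_clients * shards_per_client      # 200 shards for MNIST
    shard_size = len(labels) // num_shards            # 300 images per shard
    sorted_idx = np.argsort(labels)                   # groups identical digits together
    shard_order = np.random.permutation(num_shards)
    client_idx = {}
    for c in range(num_clients):
        own = shard_order[c * shards_per_client:(c + 1) * shards_per_client]
        client_idx[c] = np.concatenate(
            [sorted_idx[s * shard_size:(s + 1) * shard_size] for s in own]
        )
    return client_idx

mnist = datasets.MNIST("./data", train=True, download=True)
client_indices = noniid_split(np.asarray(mnist.targets))
```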
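Within a round, the sampled clients' updates are independent, so they can run in parallel. A minimal sketch with `concurrent.futures` is below; the `clients[cid].update()` interface is a hypothetical stand-in for the repo's client objects.

```python
from concurrent.futures import ThreadPoolExecutor

def run_round(clients, sampled_ids, max_workers=4):
    """Run the sampled clients' local updates in parallel worker threads."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # each call returns that client's updated weights (hypothetical interface)
        results = list(pool.map(lambda cid: clients[cid].update(), sampled_ids))
    return results
```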
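TensorBoard tracking boils down to writing per-round scalars with `torch.utils.tensorboard.SummaryWriter`; the tag names and values below are illustrative only.

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="./log")
for round_idx, (loss, acc) in enumerate([(0.9, 0.42), (0.5, 0.71)], start=1):  # dummy values
    writer.add_scalar("Loss/global", loss, round_idx)
    writer.add_scalar("Accuracy/global", acc, round_idx)
writer.close()
# inspect with: tensorboard --logdir ./log
```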
## Requirements
- See `requirements.txt`.
## Configurations
- See `config.yaml`.
## Run
```bash
python3 main.py
```
## Results
### MNIST
- Number of clients: 100 (K = 100)
- Fraction of sampled clients: 0.1 (C = 0.1)
- Number of rounds: 500 (R = 500)
- Number of local epochs: 10 (E = 10)
- Batch size: 10 (B = 10)
- Optimizer: `torch.optim.SGD`
- Criterion: `torch.nn.CrossEntropyLoss`
- Learning rate: 0.01
- Momentum: 0.9
- Initialization: Xavier
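Under these settings, a single client's local update is plain mini-batch SGD over E epochs, with Xavier initialization applied once when the global model is created. A minimal sketch follows; the function names are illustrative, not the repo's API.

```python
import torch

def init_xavier(module):
    """Xavier-initialize linear layers once, before round 1."""
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight)
        torch.nn.init.zeros_(module.bias)

def local_update(model, loader, epochs=10, lr=0.01, momentum=0.9, device="cpu"):
    """E (=10) local epochs of SGD; the batch size B (=10) is set by the loader."""
    model.to(device).train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = torch.nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            criterion(model(x), y).backward()
            optimizer.step()
    return model.state_dict()
```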
Table 1. Final accuracy and best accuracy
| Model | Final Accuracy (IID) (Round) | Best Accuracy (IID) (Round) | Final Accuracy (non-IID) (Round) | Best Accuracy (non-IID) (Round) |
|---|---|---|---|---|
| 2NN | 98.38% (500) | 98.45% (483) | 97.50% (500) | 97.65% (475) |
| CNN | 99.31% (500) | 99.34% (197) | 98.73% (500) | 99.28% (493) |
Table 2. Final loss and lowest loss
| Model | Final Loss (IID) (Round) | Lowest Loss (IID) (Round) | Final Loss (non-IID) (Round) | Lowest Loss (non-IID) (Round) |
|---|---|---|---|---|
| 2NN | 0.09296 (500) | 0.06956 (107) | 0.09075 (500) | 0.08257 (475) |
| CNN | 0.04781 (500) | 0.02497 (86) | 0.04533 (500) | 0.02413 (366) |
Figure 1. MNIST 2NN model accuracy (IID: top / non-IID: bottom)

Figure 2. MNIST CNN model accuracy (IID: top / non-IID: bottom)

## TODO
- [ ] Do CIFAR experiment (CIFAR10 dataset) & large-scale LSTM experiment (Shakespeare dataset)
- [ ] Learning rate scheduling
- [ ] More experiments with other hyperparameter settings (e.g., different combinations of B, E, K, and C)