FedModule
A modular federated learning framework that supports various kinds of FL. A universal federated learning framework, free to switch between thread and process modes.
Async-FL
This document is also available in: 中文 | English
keywords: federated-learning, asynchronous, synchronous, semi-asynchronous, personalized
Table of Contents
- Brief
- Git Branch Description
- Requirements
- Getting Started
  - Experiments
  - Docker
- Features
- Project Directory
- Framework
- Code Explanations
  - Receiver Class
  - Checker Class
- Configuration
  - Asynchronous Configuration
  - Synchronous Configuration
  - Semi-asynchronous Configuration
  - Parameter explanation
- Adding New Algorithm
  - Adding Loss Function
- Staleness Settings
- Data Distribution Settings
  - iid
  - dirichlet non-iid
  - customize non-iid
    - label distribution
    - data distribution
- Adding New Client Class
- Multi-GPU
- Existing Bugs
- Contributors
- Contact Us
Brief
Clients are simulated with threads or processes for federated learning. The modular design makes the project highly extensible and supports the current mainstream federated learning modes: synchronous, asynchronous, semi-asynchronous, personalized, and so on. Thread mode provides a convenient debugging environment, while process mode improves experimental efficiency and the quality of experimental data.
wandb synchronizes experimental data to the cloud with one click, so there is no need to worry about data loss.
Git Branch Description
The master branch is the main branch with the latest code, but some of the commits are dirty commits and not guaranteed to run properly. It is recommended to use tagged versions for better stability.
The checkout branch retains the functionality of adding clients to the system during the training process, which has been removed in the main branch. The checkout branch is not actively maintained and only supports synchronous and asynchronous FL.
Requirements
Python 3.8 + PyTorch + macOS
It has also been validated on Linux.
Both single-GPU and multi-GPU setups are supported.
Getting Started
Experiments
You can run `python main.py` directly (the main file is located in the `src/fl` directory). The program automatically reads the `config.json` file in the root directory and stores the results, together with the configuration file used, in the specified path under `results`.

You can also pass a configuration file explicitly, e.g. `python main.py ../../config.json`. Note that the path of `config.json` is interpreted relative to `main.py`.

The `config` folder in the root directory provides configuration files for several algorithms proposed in papers. The following algorithm implementations are currently available:
FedAvg
FedAsync
FedProx
FedAT
FedLC
FedDL
M-Step AsyncFL
FedAdam
Docker
You can now pull and run the Docker image directly:

```shell
docker pull desperadoccy/async-fl
docker run -it async-fl config/FedAvg-config.json
```

It likewise supports passing a config file path as a parameter. You can also build the Docker image yourself:

```shell
cd docker
docker build -t async-fl .
docker run -it async-fl config/FedAvg-config.json
```
Features
- [x] Asynchronous Federated Learning
- [x] Support model and dataset replacement
- [x] Support scheduling algorithm replacement
- [x] Support aggregation algorithm replacement
- [x] Support loss function replacement
- [x] Support client replacement
- [x] Synchronous federated learning
- [x] Semi-asynchronous federated learning
- [x] Provide test loss information
- [x] Custom label heterogeneity
- [ ] Custom data heterogeneity
- [x] Support Dirichlet distribution
- [x] wandb visualization
- [ ] Support for leaf-related datasets
- [x] Support for multiple GPUs
- [x] Docker deployment
- [x] Process/thread switching
Project Directory
```
.
├── config Common algorithm configuration files
│ ├── FedAT-config.json
│ ├── FedAsync-config.json
│ ├── FedAvg-config.json
│ ├── FedDL-config.json
│ ├── FedLC-config.json
│ ├── FedProx-config.json
│ ├── MSTEPAsync-config.json
│ ├── config.json
│ └── model_config
│ ├── CIFAR10-config.json
│ ├── ResNet18-config.json
│ └── ResNet50-config.json
├── config.json
├── config_semi.json
├── config_semi_test.json
├── config_sync.json
├── config_sync_test.json
├── config_test.json
├── doc
│ ├── params.docx
│ ├── pic
│ │ ├── fedsemi.png
│ │ ├── framework.png
│ │ └── header.png
│ ├── readme-zh.md
│ └── 参数.docx
├── docker
│ └── Dockerfile
├── license
├── readme.md
├── requirements.txt
└── src
├── checker checker implementation
│ ├── AllChecker.py
│ ├── CheckerCaller.py
│ ├── SyncChecker.py
│ └── __init__.py
├── client client implementation
│ ├── ActiveClient.py
│ ├── Client.py
│ ├── DLClient.py
│ ├── NormalClient.py
│ ├── ProxClient.py
│ ├── SemiClient.py
│ ├── TestClient.py
│ └── __init__.py
├── clientmanager client manager implementation
│ ├── BaseClientManager.py
│ ├── NormalClientManager.py
│ └── __init__.py
├── compressor compressor algorithm class
│ ├── QSGD.py
│ └── __init__.py
├── data
├── dataset
│ ├── CIFAR10.py
│ ├── FashionMNIST.py
│ ├── MNIST.py
│ └── __init__.py
├── exception
│ ├── ClientSumError.py
│ └── __init__.py
├── fl wandb running directory
│ ├── __init__.py
│ ├── main.py
│ └── wandb
├── group group algorithm class
│ ├── AbstractGroup.py
│ ├── DelayGroup.py
│ ├── GroupCaller.py
│ ├── OneGroup.py
│ └── __init__.py
├── groupmanager group manager implementation
│ ├── BaseGroupManager.py
│ ├── NormalGroupManager.py
│ └── __init__.py
├── loss loss algorithm class
│ ├── FedLC.py
│ ├── LossFactory.py
│ └── __init__.py
├── model
│ ├── CNN.py
│ └── __init__.py
├── numgenerator num generator algorithm class
│ ├── AbstractNumGenerator.py
│ ├── NumGeneratorFactory.py
│ ├── StaticNumGenerator.py
│ └── __init__.py
├── queuemanager queuemanager implementation
│ ├── AbstractQueueManager.py
│ ├── BaseQueueManger.py
│ ├── QueueListManager.py
│ ├── SingleQueueManager.py
│ └── __init__.py
├── receiver receiver implementation
│ ├── MultiQueueReceiver.py
│ ├── NoneReceiver.py
│ ├── NormalReceiver.py
│ ├── ReceiverCaller.py
│ └── __init__.py
├── results
├── schedule scheduling algorithm class
│ ├── AbstractSchedule.py
│ ├── FullSchedule.py
│ ├── NoSchedule.py
│ ├── RandomSchedule.py
│ ├── RoundRobin.py
│ ├── ScheduleCaller.py
│ └── __init__.py
├── scheduler scheduler implementation
│ ├── AsyncScheduler.py
│ ├── BaseScheduler.py
│ ├── SemiAsyncScheduler.py
│ ├── SyncScheduler.py
│ └── __init__.py
├── server server implementation
│ ├── AsyncServer.py
│ ├── BaseServer.py
│ ├── SemiAsyncServer.py
│ ├── SyncServer.py
│ └── __init__.py
├── test for test
│ ├── __init__.py
│ ├── test.ipynb
│ └── test.py
├── update update algorithm class
│ ├── AbstractUpdate.py
│ ├── AsyncAvg.py
│ ├── FedAT.py
│ ├── FedAsync.py
│ ├── FedAvg.py
│ ├── FedDL.py
│ ├── StepAsyncAvg.py
│ ├── UpdateCaller.py
│ └── __init__.py
├── updater updater implementation
│ ├── AsyncUpdater.py
│ ├── BaseUpdater.py
│ ├── SemiAsyncUpdater.py
│ ├── SyncUpdater.py
│ └── __init__.py
└── utils
├── ConfigManager.py
├── GlobalVarGetter.py
├── IID.py
├── JsonTool.py
├── ModelTraining.py
├── ModuleFindTool.py
├── Plot.py
├── ProcessTool.py
├── Queue.py
├── Random.py
├── Time.py
├── Tools.py
└── __init__.py
```
The "Time" file under the "utils" package is an implementation of a multi-threaded time acquisition class, and the "Queue" file is an implementation of related functionalities for the "queue" module, as some functionalities of the "queue" module are not yet implemented on macOS.
Framework
The overall framework of the project is illustrated in `doc/pic/framework.png`.
Code Explanations
Receiver Class
The receiver in synchronous and semi-asynchronous federated learning is used to check whether the updates received during the current global iteration meet the conditions set, such as whether all designated clients have uploaded their updates. If the conditions are met, the updater process will be triggered to perform global aggregation.
Checker Class
In synchronous and semi-asynchronous federated learning, after a client completes its training, it hands its weights to the checker class, which decides, according to its own logic, whether the update meets the upload criteria and therefore whether to accept or discard it.
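To illustrate the division of labour, a custom checker could be sketched as follows. The constructor, the `check` method, and the structure of the `update` argument are assumptions about the interface rather than the real API; see `src/checker/AllChecker.py` and `CheckerCaller.py` for the actual one.

```python
# A hypothetical checker sketch; names and signatures are assumptions,
# not the framework's real interface.
class ThresholdChecker:
    def __init__(self, config):
        # values from the checker's "params" configuration item
        self.max_staleness = config.get("max_staleness", 10)

    def check(self, update):
        # return True to accept the client's update, False to discard it
        return update.get("staleness", 0) <= self.max_staleness
```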
Configuration
Example configuration files are provided for each mode:
- async mode example
- sync mode example
- semi-async mode example
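To give a rough sense of the overall shape, here is a heavily abridged configuration sketch. It is not a complete, runnable file, and the module paths and class names (e.g. `dataset.MNIST.MNIST`, `model.CNN.CNN`) are guesses based on the project tree; refer to the files under `config/` for working examples.

```json
{
  "wandb": {
    "enabled": false,
    "project": "fedmodule",
    "name": "example-run"
  },
  "global": {
    "multi_gpu": false,
    "client_num": 50,
    "iid": true,
    "stale": false,
    "dataset": {
      "path": "dataset.MNIST.MNIST",
      "params": {}
    }
  },
  "server": {
    "epochs": 100,
    "model": {
      "path": "model.CNN.CNN",
      "params": {}
    }
  },
  "client": {
    "epochs": 2,
    "batch_size": 32,
    "loss": "torch.nn.functional.cross_entropy",
    "optimizer": {
      "path": "torch.optim.SGD",
      "params": {
        "lr": 0.01
      }
    }
  }
}
```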
Parameter explanation
| Parameter | Sub-parameter | Type | Explanation |
| --- | --- | --- | --- |
| wandb | enabled | bool | whether to enable wandb |
| | project | string | project name |
| | name | string | the name of this run |
| global | use_file_system | bool | whether to enable the file system as the torch multi-thread sharing strategy |
| | multi_gpu | bool | whether to enable multi-GPU (see Multi-GPU) |
| | experiment | string | the name of this run |
| | stale | | see Staleness Settings |
| | dataset.path | string | the path of the dataset |
| | dataset.params | dict | required parameters |
| | iid | | see Data Distribution Settings |
| | client_num | int | number of clients |
| server | path | string | the path of the server |
| | epochs | int | global epochs |
| | model.path | string | the path of the model |
| | model.params | dict | required parameters |
| | scheduler.path | string | the path of the scheduler |
| | scheduler.schedule.path | string | the path of the schedule |
| | scheduler.schedule.params | dict | required parameters |
| | scheduler.other_params | * | other parameters |
| | updater.path | string | the path of the updater |
| | updater.update.path | string | the path of the update |
| | updater.update.params | dict | required parameters |
| | updater.loss | | see Adding Loss Function |
| | updater.num_generator | | num generator settings |
| group | path | string | the path of the group |
| | params | dict | required parameters |
| client_manager | path | string | the path of the client manager |
| group_manager | path | string | the path of the group manager |
| | group_method.path | string | the path of the group method |
| | group_method.params | dict | required parameters |
| queue_manager | path | string | the path of the queue manager |
| | receiver.path | string | the path of the receiver |
| | receiver.params | dict | required parameters |
| | checker.path | string | the path of the checker |
| | checker.params | dict | required parameters |
| client | path | string | the path of the client |
| | epochs | int | local epochs |
| | batch_size | int | batch size |
| | model.path | string | the path of the model |
| | model.params | dict | required parameters |
| | loss | | see Adding Loss Function |
| | mu | float | proximal term coefficient |
| | optimizer.path | string | the path of the optimizer |
| | optimizer.params | dict | required parameters |
| | other_params | * | other parameters |
Adding New Algorithm
To allow clients/servers to call your own algorithms or implementation classes (note: all algorithm implementations must be in class form), the following steps are required:
- Add your own implementation to the corresponding location (dataset, model, schedule, update, client, loss)
- Import the class in the `__init__.py` file of the corresponding package, for example `from model import CNN`
- Declare it in the configuration file; `model_path` corresponds to the path where the new algorithm is located.
- The `checker`, `group`, `receiver`, `schedule`, and `update` modules additionally need an invocation method added to the corresponding `Caller` class.
- The `loss` and `numgenerator` modules additionally need an invocation method added to the corresponding `Factory` class.
In addition, any parameters the algorithm needs can be declared in the `params` configuration item.
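For instance, a newly added scheduling algorithm would end up being declared roughly like this (the class path `schedule.MySchedule.MySchedule` and the `select_ratio` parameter are made up for illustration):

```json
"schedule": {
  "path": "schedule.MySchedule.MySchedule",
  "params": {
    "select_ratio": 0.3
  }
}
```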
The `model`, `optim`, and `loss` modules now also support built-in implementation classes from `torch`, for example:

```json
"model": {
  "path": "torchvision.models.resnet18",
  "params": {
    "pretrained": true,
    "num_classes": 10
  }
}
```
Adding Loss Function
The loss function is now created by the `LossFactory` class. You can either use built-in algorithms from `torch` or implement your own.

The loss configuration supports three settings. The first option is the plain string form commonly used in the configuration files:

```json
"loss": "torch.nn.functional.cross_entropy"
```
In this case, the program generates the loss function directly in the functional style.

The second option generates an object-based loss:

```json
"loss": {
  "path": "loss.myloss.MyLoss",
  "params": {}
}
```

Here, you specify the path to your custom loss class and provide any necessary parameters in the `params` field.
The third option generates a loss based on the `type` field:

```json
"loss": {
  "type": "func",
  "path": "loss.myloss.MyLoss",
  "params": {}
}
```
With this option, you additionally provide the `type` field (here `"func"`); the rest of the process is the same as the object-based approach.
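For reference, a custom loss class such as the `loss.myloss.MyLoss` referenced above could look like the following minimal sketch (the `scale` parameter is hypothetical; the only assumption is that the class behaves like a standard `torch.nn` loss module):

```python
# src/loss/myloss.py -- a minimal sketch of a custom loss class.
import torch
import torch.nn.functional as F


class MyLoss(torch.nn.Module):
    def __init__(self, scale=1.0):
        # "scale" is a hypothetical parameter supplied via "params"
        super().__init__()
        self.scale = scale

    def forward(self, outputs, targets):
        # plain cross entropy scaled by a configurable factor
        return self.scale * F.cross_entropy(outputs, targets)
```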
Staleness Settings
`stale` has three settings. The first is the one shown in the example configuration files above:

```json
"stale": {
  "step": 5,
  "shuffle": true,
  "list": [10, 10, 10, 5, 5, 5, 5]
}
```
The program generates a sequence of random integers based on the provided `step` and `list`. For example, with the setting above, the program generates 10 zeros, 10 values in (0, 5), 10 values in [5, 10), and so on, and shuffles them if `shuffle` is set to true. Each value is then assigned to a client, and the client sleeps for that many seconds after each round of training. When the JSON file is stored with the experimental results, this setting is automatically converted to the third form.

The second option is to set `stale` to false, in which case the program sets the delay of every client to 0:

```json
"stale": false
```
The third option is a list of integers; the program assigns these delay values directly to the clients:

```json
"stale": [1, 2, 3, 1, 4]
```
Data Distribution Settings
iid
When `iid` is set to true (in fact, this is also the default behaviour when it is set to false), the data is distributed to the clients in an independent and identically distributed (IID) way:

```json
"iid": true
```
dirichlet non-iid
When `customize` under `iid` is set to false or is not set, the data is distributed to the clients according to a Dirichlet distribution, where `beta` is the parameter of the Dirichlet distribution:

```json
"iid": {
  "customize": false,
  "beta": 0.5
}
```

or

```json
"iid": {
  "beta": 0.5
}
```
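For intuition, per-client class proportions under a Dirichlet split are typically drawn as in the following sketch (a generic numpy illustration, not the framework's exact implementation):

```python
import numpy as np

# Generic illustration (not the framework's code): for each class, draw the
# fractions assigned to each client from a Dirichlet distribution with
# concentration beta; smaller beta -> more skewed (more non-iid) splits.
num_clients, num_classes, beta = 20, 10, 0.5
proportions = np.random.dirichlet([beta] * num_clients, size=num_classes)
# proportions[c, k] is the fraction of class c's samples given to client k
print(proportions.shape)  # (10, 20)
```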
customize non-iid
Customized non-iid settings are divided into two parts: label non-iid and data-quantity non-iid. Currently, only random generation is provided for the data quantity; personalized settings will be introduced in future versions.
When enabling the customized setting, you need to set `customize` to true and configure `label` and `data` separately:

```json
"iid": {
  "customize": true
}
```
label distribution
The label setting is similar to the staleness setting and supports three modes. The first is the form used in the example configuration file:

```json
"label": {
  "step": 1,
  "list": [10, 10, 30]
}
```
The configuration above generates 10 clients holding data from 1 label, 10 clients with 2 labels, and 30 clients with 3 labels.

If `step` is set to 2, the program instead generates 10 clients with 1 label, 10 clients with 3 labels, and 30 clients with 5 labels.
The second mode is a two-dimensional array (a mapping from client index to a list of labels), which the program assigns to the clients as-is:

```json
"label": {
  "0": [1, 2, 3, 8],
  "1": [2, 4],
  "2": [4, 7],
  "3": [0, 2, 3, 6, 9],
  "4": [5]
}
```
The third mode is a one-dimensional array giving the number of labels each client holds; its length must equal the number of clients:

```json
"label": {
  "list": [4, 5, 10, 1, 2, 3, 4]
}
```
The configuration above sets the number of labels per client: client 0 holds data from 4 labels, client 1 from 5 labels, and so on.

Currently, there are two randomization methods for generating label non-iid data. One is pure randomization, which may (with extremely low probability) leave some label unused by every client and thereby reduce accuracy. The other uses a shuffle algorithm to guarantee that every label is selected, but it cannot generate data with an uneven label distribution. The shuffle algorithm is controlled by the `shuffle` parameter, as shown below:

```json
"label": {
  "shuffle": true,
  "list": [4, 5, 10, 1, 2, 3, 4]
}
```
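The difference between the two strategies can be sketched as follows (a generic illustration, not the framework's code):

```python
import random

num_classes = 10
counts = [4, 5, 10, 1, 2, 3, 4]   # labels per client, as in the "list" setting

# 1) pure random: each client samples its labels independently, so in rare
#    cases some label may end up unused by every client
pure_random = [random.sample(range(num_classes), n) for n in counts]

# 2) shuffle-based: hand out labels from a shuffled pool so that every label
#    is assigned before any label is reused
pool, shuffle_based = [], []
for n in counts:
    labels = set()
    while len(labels) < n:
        if not pool:
            pool = list(range(num_classes))
            random.shuffle(pool)
        labels.add(pool.pop())
    shuffle_based.append(sorted(labels))
```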
data distribution
The data setting is relatively simple. There are currently two modes; the first is to leave it empty:

```json
"data": {}
```
That is, no non-iid setting is performed on the data quantity.
The second mode is the one shown in the example configuration file:

```json
"data": {
  "max": 500,
  "min": 400
}
```
That is, the data quantity for each client will be randomly distributed between 400 and 500, and will be evenly distributed among the labels by default.
The data quantity distribution is still relatively simple at this point, and will be gradually improved in the future.
Adding New Client Class
Currently, a replacement client class needs to inherit from `AsyncClient` or `SyncClient`, and any new parameters are passed into the class through the `client` configuration item.
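A minimal custom-client sketch might look like the following. The base class chosen here (`NormalClient`, which appears in the project tree) and the use of `*args`/`**kwargs` are assumptions; check the classes under `src/client` for the actual constructor signature. The new class would then be referenced in the `client` configuration item with a path such as `client.MyClient.MyClient` (hypothetical).

```python
# src/client/MyClient.py -- a hypothetical custom client sketch.
from client.NormalClient import NormalClient


class MyClient(NormalClient):
    def __init__(self, *args, **kwargs):
        # forward all framework-provided arguments to the base class
        super().__init__(*args, **kwargs)
        # extra values declared under the "client" configuration item
        # would be read here (attribute names are up to you)
```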
Multi-GPU
The multi-GPU feature of this project does not mean multi-GPU parallel training. Each client is still trained on a single GPU, but viewed as a whole the clients run across multiple GPUs; that is, the training tasks of the clients are distributed evenly across the GPUs visible to the program. The GPU bound to each client is fixed at initialization rather than chosen anew every training round, so a serious imbalance in GPU load is still possible.
This feature is controlled by the `multi_gpu` switch in the `global` settings.
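The effect of this static binding can be pictured with a simple round-robin assignment (an illustration only, not the framework's code):

```python
import torch

# Illustration only (not the framework's code): bind each client to a GPU
# round-robin at initialization; the binding does not change between rounds.
num_clients = 8
devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())] or ["cpu"]
bindings = {cid: devices[cid % len(devices)] for cid in range(num_clients)}
# e.g. with 2 visible GPUs: {0: 'cuda:0', 1: 'cuda:1', 2: 'cuda:0', ...}
```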
Existing Bugs
Currently, there is a core issue in the framework: communication between clients and servers is implemented with `multiprocessing` queues. When a CUDA tensor is put into a queue and retrieved by another thread, it can cause a memory leak and may crash the program.
This bug is caused by the interaction of PyTorch with the multiprocessing queue, and the current workaround is to put non-CUDA tensors into the queue and convert them back to CUDA tensors during aggregation. Therefore, when adding an aggregation algorithm, code like the following is needed:
```python
updated_parameters = {}
for key, var in client_weights.items():
    # clone the CPU tensor received from the queue ...
    updated_parameters[key] = var.clone()
    if torch.cuda.is_available():
        # ... and move it back to the GPU before aggregation
        updated_parameters[key] = updated_parameters[key].cuda()
```
Contributors
- Desperadoccy
- Jzj007
- Cauchy
Contact Us
We created a QQ group to discuss the async-FL framework and federated learning. Everyone is welcome to join.
Group number: 895896624
QQ: 527707607
email: [email protected]
Suggestions for the project are welcome. If you'd like to contribute to this project, please contact us.