easyFL
Using EfficientNet-b0 causes qfedavg to fail
Thank you for this federated learning framework! It is very concise and easy to port. However, I have a question: when I switch the model to EfficientNet-b0 and run qfedavg, fedfv, or fedprox on the CIFAR-10 dataset, the loss never changes from start to finish. This is the model I defined:

```python
from torch import nn
from flgo.utils.fmodule import FModule
from efficientnet_pytorch import EfficientNet


class Model(FModule):
    def __init__(self):
        super(Model, self).__init__()
        pretrained = True
        self.base_model = (
            EfficientNet.from_pretrained("efficientnet-b0")
            if pretrained
            else EfficientNet.from_name("efficientnet-b0")
        )
        # self.base_model = torchvision.models.efficientnet_v2_s(pretrained=pretrained)
        nftrs = self.base_model._fc.in_features
        # print("Number of features output by EfficientNet", nftrs)
        self.base_model._fc = nn.Linear(nftrs, 10)

    def forward(self, x):
        # Convolution layers
        x = self.base_model.extract_features(x)
        # Pooling and final linear layer
        feature_x = self.base_model._avg_pooling(x)
        if self.base_model._global_params.include_top:
            x = feature_x.flatten(start_dim=1)
            x = self.base_model._dropout(x)
            x = self.base_model._fc(x)
        return x


def init_local_module(object):
    pass


def init_global_module(object):
    if 'Server' in object.__class__.__name__:
        object.model = Model().to(object.device)
```
This produces the following result:
Hi, someone raised the same issue in the flgo discussion group earlier. The cause is that the qfedavg code uses the `norm` interface to compute the model norm directly, and that interface by default calls `flgo.utils.fmodule._model_dict_norm`. Since `model.state_dict()` usually also contains statistical buffers (e.g. BN running statistics), the norm returned by this interface is extremely large for any model with BN layers. If the update is then divided by this norm, the model update gets scaled down to nearly zero, which is exactly the behavior you observed. I am pasting my fixed qfedavg code below, and will merge this update into flgo later.
```python
"""This is a non-official implementation of 'Fair Resource Allocation in
Federated Learning' (http://arxiv.org/abs/1905.10497). This implementation
refers to the official github repository
https://github.com/litian96/fair_flearn
"""
import copy

import flgo.algorithm.fedbase as fedbase
import flgo.utils.fmodule as fmodule


class Server(fedbase.BasicServer):
    def initialize(self, *args, **kwargs):
        self.init_algo_para({'q': 1.0})

    def iterate(self):
        self.selected_clients = self.sample()
        res = self.communicate(self.selected_clients)
        self.model = self.model - fmodule._model_sum(res['dk']) / sum(res['hk'])
        return len(self.received_clients) > 0


class Client(fedbase.BasicClient):
    def unpack(self, package):
        model = package['model']
        self.global_model = copy.deepcopy(model)
        return model

    def pack(self, model):
        Fk = self.test(self.global_model, 'train')['loss'] + 1e-8
        L = 1.0 / self.learning_rate
        delta_wk = L * (self.global_model - model)
        dk = (Fk ** self.q) * delta_wk
        # Compute the squared norm over parameters() only, so that BN
        # statistical buffers do not inflate the norm.
        norm_dwk = 0.0
        for p in delta_wk.parameters():
            norm_dwk += (p ** 2).sum()
        hk = self.q * (Fk ** (self.q - 1)) * norm_dwk + L * (Fk ** self.q)
        self.global_model = None
        return {'dk': dk, 'hk': hk}
```
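To see how much the BN buffers inflate a state_dict-based norm, here is a minimal standalone check (plain PyTorch, not flgo code; the tiny `Linear` + `BatchNorm1d` model is just for illustration):

```python
import torch
from torch import nn

torch.manual_seed(0)

# A toy model with one BN layer.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

# A few training-mode forward passes so BN accumulates running statistics
# (running_mean ~ 5, running_var ~ 100, num_batches_tracked = 100).
model.train()
for _ in range(100):
    model(torch.randn(32, 4) * 10 + 5)

def norm_over(tensors):
    return torch.sqrt(sum((t.float() ** 2).sum() for t in tensors)).item()

param_norm = norm_over(model.parameters())            # trainable weights only
dict_norm = norm_over(model.state_dict().values())    # also includes BN buffers

print(param_norm, dict_norm)  # dict_norm is orders of magnitude larger
```

Dividing an update by `dict_norm` instead of `param_norm` therefore shrinks the update toward zero, which is why the loss appears frozen.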
Replacing every norm computation with one based on `model.parameters()` fixes the problem. That said, since BN and non-IID data are inherently in conflict in federated learning, I recommend simply using a model without BN, or replacing BN with GN (GroupNorm).
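For the BN-to-GN suggestion, a recursive swap can be done in a few lines. This is only a sketch I am writing for illustration (`bn_to_gn` is not part of flgo), and it assumes 2D BatchNorm layers; the group count falls back to 1 when it does not divide the channel count:

```python
import torch
from torch import nn

def bn_to_gn(module: nn.Module, num_groups: int = 32) -> nn.Module:
    """Recursively replace every BatchNorm2d with a GroupNorm over the same channels."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            groups = num_groups if child.num_features % num_groups == 0 else 1
            setattr(module, name, nn.GroupNorm(groups, child.num_features))
        else:
            bn_to_gn(child, num_groups)
    return module

# Example: a tiny conv block; after conversion no BatchNorm2d remains.
net = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU())
net = bn_to_gn(net)
out = net(torch.randn(2, 3, 8, 8))
```

Because GroupNorm carries no running statistics, both the dict-norm pathology above and the BN/non-IID mismatch across clients disappear.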