Facing issue with Flower Simulation with ResNet18 and MNIST dataset
Describe the bug
I was trying an example project of Flower Simulation (Flower Simulation Step by Step Pytorch - Part II). Everything went very well until I tried to change the model to resnet18, as given below:
class Net(nn.Module):
    def __init__(self, num_classes: int) -> None:
        super(Net, self).__init__()
        self.model = models.resnet18()
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)
        summary(self.model, input_size=(1, 28, 28))  # <<== THIS LINE

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        return x
If I add summary(self.model, input_size=(1, 28, 28)) at the end of the __init__() method, everything works. But when I remove it, I get this error:

    input_param = input_param[0]
IndexError: index 0 is out of bounds for dimension 0 with size 0

It is raised from evaluate_fn in server.py:
params_dict = zip(model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True) # <= At this line I'm getting error
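As a side note (my own debugging aid, not from the tutorial): a check like the sketch below, placed right before the load_state_dict call, should show which entry goes wrong, since it compares the shapes the model expects against the tensors just built from the received ndarrays. It assumes model and the state_dict constructed above are in scope.

# Debugging sketch (hypothetical): flag any converted tensor whose shape does
# not match what the model's state_dict expects for that key.
reference = model.state_dict()
for key, value in state_dict.items():
    if value.shape != reference[key].shape:
        print(f"{key}: expected {tuple(reference[key].shape)}, got {tuple(value.shape)}")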
Steps/Code to Reproduce
Clone the repository from Flower Simulation Step by Step Pytorch Part-II and follow the instructions to set up the environment.
Then change the model to resnet18 in the model.py file, as given below:
import torch
import torch.nn as nn
import torchvision.models as models
from flwr.common.parameter import ndarrays_to_parameters
from torchsummary import summary


class Net(nn.Module):
    def __init__(self, num_classes: int) -> None:
        super(Net, self).__init__()
        self.model = models.resnet18()
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)
        summary(self.model, input_size=(1, 28, 28))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        return x
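As a quick sanity check (my addition, not part of the tutorial), the modified model can be run on a dummy single-channel batch to confirm that the new conv1 and fc layers line up with MNIST-shaped inputs:

# Hypothetical smoke test, reusing the imports above.
net = Net(num_classes=10)
dummy = torch.randn(4, 1, 28, 28)  # batch of 4 grayscale 28x28 images
out = net(dummy)
print(out.shape)  # expected: torch.Size([4, 10])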
Following is the list of packages installed in the conda environment:
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_gnu conda-forge
absl-py 2.1.0 pypi_0 pypi
aiohttp 3.9.3 pypi_0 pypi
aiosignal 1.3.1 pypi_0 pypi
antlr4-python3-runtime 4.9.3 pypi_0 pypi
asttokens 2.4.1 pyhd8ed1ab_0 conda-forge
astunparse 1.6.3 pypi_0 pypi
async-timeout 4.0.3 pypi_0 pypi
attrs 23.2.0 pypi_0 pypi
blas 1.0 mkl
brotli-python 1.0.9 py39h6a678d5_7
bzip2 1.0.8 h5eee18b_5
ca-certificates 2024.3.11 h06a4308_0
certifi 2024.2.2 pyhd8ed1ab_0 conda-forge
cffi 1.16.0 pypi_0 pypi
charset-normalizer 2.0.4 pyhd3eb1b0_0
click 8.1.7 pypi_0 pypi
comm 0.2.2 pyhd8ed1ab_0 conda-forge
contourpy 1.2.0 pypi_0 pypi
cryptography 41.0.7 pypi_0 pypi
cycler 0.12.1 pypi_0 pypi
datasets 2.18.0 pypi_0 pypi
debugpy 1.6.7 py39h6a678d5_0
decorator 5.1.1 pyhd8ed1ab_0 conda-forge
dill 0.3.8 pypi_0 pypi
exceptiongroup 1.2.0 pyhd8ed1ab_2 conda-forge
executing 2.0.1 pyhd8ed1ab_0 conda-forge
ffmpeg 4.3 hf484d3e_0 pytorch
filelock 3.13.3 pypi_0 pypi
flatbuffers 24.3.25 pypi_0 pypi
flwr 1.7.0 pypi_0 pypi
flwr-datasets 0.1.0 pypi_0 pypi
fonttools 4.50.0 pypi_0 pypi
freetype 2.12.1 h4a9f257_0
frozenlist 1.4.1 pypi_0 pypi
fsspec 2024.2.0 pypi_0 pypi
gast 0.5.4 pypi_0 pypi
gmp 6.2.1 h295c915_3
gnutls 3.6.15 he1e5248_0
google-pasta 0.2.0 pypi_0 pypi
grpcio 1.62.1 pypi_0 pypi
h5py 3.10.0 pypi_0 pypi
huggingface-hub 0.22.1 pypi_0 pypi
hydra-core 1.3.2 pypi_0 pypi
idna 3.4 py39h06a4308_0
importlib-metadata 7.1.0 pyha770c72_0 conda-forge
importlib-resources 6.4.0 pypi_0 pypi
importlib_metadata 7.1.0 hd8ed1ab_0 conda-forge
intel-openmp 2023.1.0 hdb19cb5_46306
ipykernel 6.29.3 pyhd33586a_0 conda-forge
ipython 8.18.1 pyh707e725_3 conda-forge
iterators 0.0.2 pypi_0 pypi
jedi 0.19.1 pyhd8ed1ab_0 conda-forge
jpeg 9e h5eee18b_1
jsonschema 4.21.1 pypi_0 pypi
jsonschema-specifications 2023.12.1 pypi_0 pypi
jupyter_client 8.6.1 pyhd8ed1ab_0 conda-forge
jupyter_core 5.7.2 py39hf3d152e_0 conda-forge
keras 3.1.1 pypi_0 pypi
kiwisolver 1.4.5 pypi_0 pypi
lame 3.100 h7b6447c_0
lcms2 2.12 h3be6417_0
ld_impl_linux-64 2.38 h1181459_1
lerc 3.0 h295c915_0
libclang 18.1.1 pypi_0 pypi
libdeflate 1.17 h5eee18b_1
libffi 3.4.4 h6a678d5_0
libgcc-ng 13.2.0 h807b86a_5 conda-forge
libgomp 13.2.0 h807b86a_5 conda-forge
libiconv 1.16 h7f8727e_2
libidn2 2.3.4 h5eee18b_0
libpng 1.6.39 h5eee18b_0
libsodium 1.0.18 h36c2ea0_1 conda-forge
libstdcxx-ng 11.2.0 h1234567_1
libtasn1 4.19.0 h5eee18b_0
libtiff 4.5.1 h6a678d5_0
libunistring 0.9.10 h27cfd23_0
libwebp-base 1.3.2 h5eee18b_0
lz4-c 1.9.4 h6a678d5_0
markdown 3.6 pypi_0 pypi
markdown-it-py 3.0.0 pypi_0 pypi
markupsafe 2.1.5 pypi_0 pypi
matplotlib 3.8.3 pypi_0 pypi
matplotlib-inline 0.1.6 pyhd8ed1ab_0 conda-forge
mdurl 0.1.2 pypi_0 pypi
mkl 2023.1.0 h213fc3f_46344
mkl-service 2.4.0 py39h5eee18b_1
mkl_fft 1.3.8 py39h5eee18b_0
mkl_random 1.2.4 py39hdb19cb5_0
ml-dtypes 0.3.2 pypi_0 pypi
msgpack 1.0.8 pypi_0 pypi
multidict 6.0.5 pypi_0 pypi
multiprocess 0.70.16 pypi_0 pypi
namex 0.0.7 pypi_0 pypi
ncurses 6.4 h6a678d5_0
nest-asyncio 1.6.0 pyhd8ed1ab_0 conda-forge
nettle 3.7.3 hbbd107a_1
numpy 1.26.4 py39h5f9d8c6_0
numpy-base 1.26.4 py39hb5e798b_0
omegaconf 2.3.0 pypi_0 pypi
openh264 2.1.1 h4ff587b_0
openjpeg 2.4.0 h3ad879b_0
openssl 3.2.1 hd590300_1 conda-forge
opt-einsum 3.3.0 pypi_0 pypi
optree 0.11.0 pypi_0 pypi
packaging 24.0 pyhd8ed1ab_0 conda-forge
pandas 2.2.1 pypi_0 pypi
parso 0.8.4 pyhd8ed1ab_0 conda-forge
pexpect 4.9.0 pyhd8ed1ab_0 conda-forge
pickleshare 0.7.5 py_1003 conda-forge
pillow 10.2.0 py39h5eee18b_0
pip 23.3.1 py39h06a4308_0
platformdirs 4.2.0 pyhd8ed1ab_0 conda-forge
prompt-toolkit 3.0.42 pyha770c72_0 conda-forge
protobuf 4.25.3 pypi_0 pypi
psutil 5.9.8 py39hd1e30aa_0 conda-forge
ptyprocess 0.7.0 pyhd3deb0d_0 conda-forge
pure_eval 0.2.2 pyhd8ed1ab_0 conda-forge
pyarrow 15.0.2 pypi_0 pypi
pyarrow-hotfix 0.6 pypi_0 pypi
pycparser 2.21 pypi_0 pypi
pycryptodome 3.20.0 pypi_0 pypi
pydantic 1.10.14 pypi_0 pypi
pygments 2.17.2 pyhd8ed1ab_0 conda-forge
pyparsing 3.1.2 pypi_0 pypi
pysocks 1.7.1 py39h06a4308_0
python 3.9.19 h955ad1f_0
python-dateutil 2.9.0.post0 pypi_0 pypi
python_abi 3.9 2_cp39 conda-forge
pytorch 1.13.1 py3.9_cpu_0 pytorch
pytorch-mutex 1.0 cpu pytorch
pytz 2024.1 pypi_0 pypi
pyyaml 6.0.1 pypi_0 pypi
pyzmq 25.1.2 py39h6a678d5_0
ray 2.6.3 pypi_0 pypi
readline 8.2 h5eee18b_0
referencing 0.34.0 pypi_0 pypi
requests 2.31.0 py39h06a4308_1
rich 13.7.1 pypi_0 pypi
rpds-py 0.18.0 pypi_0 pypi
scipy 1.12.0 pypi_0 pypi
setuptools 68.2.2 py39h06a4308_0
six 1.16.0 pyh6c4a22f_0 conda-forge
sqlite 3.41.2 h5eee18b_0
stack_data 0.6.2 pyhd8ed1ab_0 conda-forge
tbb 2021.8.0 hdb19cb5_0
tensorboard 2.16.2 pypi_0 pypi
tensorboard-data-server 0.7.2 pypi_0 pypi
tensorflow-io-gcs-filesystem 0.36.0 pypi_0 pypi
termcolor 2.4.0 pypi_0 pypi
tk 8.6.12 h1ccaba5_0
torchaudio 0.13.1 py39_cpu pytorch
torchsummary 1.5.1 pypi_0 pypi
torchvision 0.14.1 py39_cpu pytorch
tornado 6.4 py39hd1e30aa_0 conda-forge
tqdm 4.66.2 pypi_0 pypi
traitlets 5.14.2 pyhd8ed1ab_0 conda-forge
typing_extensions 4.9.0 py39h06a4308_1
tzdata 2024.1 pypi_0 pypi
urllib3 2.1.0 py39h06a4308_1
wcwidth 0.2.13 pyhd8ed1ab_0 conda-forge
werkzeug 3.0.2 pypi_0 pypi
wheel 0.41.2 py39h06a4308_0
wrapt 1.16.0 pypi_0 pypi
xxhash 3.4.1 pypi_0 pypi
xz 5.4.6 h5eee18b_0
yarl 1.9.4 pypi_0 pypi
zeromq 4.3.5 h6a678d5_0
zipp 3.18.1 pypi_0 pypi
zlib 1.2.13 h5eee18b_0
zstd 1.5.5 hc292b87_0
requirement.txt file
datasets==2.18.0
flwr==1.7.0
hydra-core==1.3.2
omegaconf==2.3.0
torch==1.13.1
torchvision==0.14.1
flwr[simulation]>=1.0, <2.0
matplotlib==3.8.3
scipy==1.12.0
torchsummary==1.5.1
Expected Results
Following is the output when it runs successfully (i.e. with the summary(self.model, input_size=(1, 28, 28)) line added):
{'history': History (loss, distributed):
    round 1: 6.738090056180954
    round 2: 3.8934330970048903
History (loss, centralized):
    round 0: 366.1482033729553
    round 1: 97.4027541577816
    round 2: 52.76616382226348
History (metrics, centralized):
    {'accuracy': [(0, 0.1086), (1, 0.8021), (2, 0.8959)]}
Actual Results
When I remove the line summary(self.model, input_size=(1, 28, 28)), I get the following error:
[2024-04-08 09:43:34,760][flwr][INFO] - Initializing global parameters
[2024-04-08 09:43:34,761][flwr][INFO] - Requesting initial parameters from one random client
[2024-04-08 09:43:37,337][flwr][INFO] - Received initial parameters from one random client
[2024-04-08 09:43:37,338][flwr][INFO] - Evaluating initial parameters
[2024-04-08 09:43:37,644][flwr][ERROR] - index 0 is out of bounds for dimension 0 with size 0
[2024-04-08 09:43:37,646][flwr][ERROR] - Traceback (most recent call last):
File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/flwr/simulation/app.py", line 308, in start_simulation
hist = run_fl(
File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/flwr/server/app.py", line 225, in run_fl
hist = server.fit(num_rounds=config.num_rounds, timeout=config.round_timeout)
File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/flwr/server/server.py", line 92, in fit
res = self.strategy.evaluate(0, parameters=self.parameters)
File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/flwr/server/strategy/fedavg.py", line 165, in evaluate
eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
File "/root/development/machine-learning-project/server.py", line 42, in evaluate_fn
model.load_state_dict(state_dict, strict=True)
File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1657, in load_state_dict
load(self, state_dict)
File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1645, in load
load(child, child_state_dict, child_prefix)
File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1645, in load
load(child, child_state_dict, child_prefix)
File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1639, in load
module._load_from_state_dict(
File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 110, in _load_from_state_dict
super(_NormBase, self)._load_from_state_dict(
File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _load_from_state_dict
input_param = input_param[0]
IndexError: index 0 is out of bounds for dimension 0 with size 0
[2024-04-08 09:43:37,648][flwr][ERROR] - Your simulation crashed :(. This could be because of several reasons. The most common are:
> Sometimes, issues in the simulation code itself can cause crashes. It's always a good idea to double-check your code for any potential bugs or inconsistencies that might be contributing to the problem. For example:
- You might be using a class attribute in your clients that hasn't been defined.
- There could be an incorrect method call to a 3rd party library (e.g., PyTorch).
- The return types of methods in your clients/strategies might be incorrect.
> Your system couldn't fit a single VirtualClient: try lowering `client_resources`.
> All the actors in your pool crashed. This could be because:
- You clients hit an out-of-memory (OOM) error and actors couldn't recover from it. Try launching your simulation with more generous `client_resources` setting (i.e. it seems {'num_cpus': 1, 'num_gpus': 0.0} is not enough for your run). Use fewer concurrent actors.
- You were running a multi-node simulation and all worker nodes disconnected. The head node might still be alive but cannot accommodate any actor with resources: {'num_cpus': 1, 'num_gpus': 0.0}.
Take a look at the Flower simulation examples for guidance <https://flower.dev/docs/framework/how-to-run-simulations.html>.
Hi @EzyHow, have you added that summary(self.model, input_size=(1, 28, 28)) somewhere else, maybe also in the evaluation code in server.py? I wonder if torchsummary is adding something to the state_dict...
Flower Simulation Step by Step Pytorch Part-II
Kindly check this repository for the detailed code: Testing Flower Simulation
In this repository, please go through the main.log files for the three different scenarios given in the output directory.
Hello,
I encountered the same issue and found a solution. I noticed the ndarrays_to_model function in src/model_utils.py. The relevant code is:
def ndarrays_to_model(model: torch.nn.ModuleList, params: List[np.ndarray]):
    """Set model weights from a list of NumPy ndarrays."""
    params_dict = zip(model.state_dict().keys(), params)
    state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)
Therefore, I changed
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
to
state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict})
in the set_parameters function in client.py and in evaluate_fn in server.py. Please also import numpy:
import numpy as np
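Putting it together, a minimal sketch of the patched helper (the name set_parameters follows the tutorial; adjust it to your own file layout):

from collections import OrderedDict
from typing import List

import numpy as np
import torch


def set_parameters(model: torch.nn.Module, parameters: List[np.ndarray]) -> None:
    """Copy a list of NumPy ndarrays into the model's state_dict."""
    params_dict = zip(model.state_dict().keys(), parameters)
    # torch.from_numpy preserves 0-dim arrays (e.g. BatchNorm's num_batches_tracked)
    # as scalar tensors.
    state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)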
I hope it will work for you as well.
This worked for me. How did you come to this solution? I can't find a reason for it to work.
This worked for me. How did you come to this solution? I can't find a reason for it to work.
I am not sure, but I noticed that one function builds the tensors with torch.Tensor directly while the other goes through NumPy with torch.from_numpy. Maybe their internal behavior is different.
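A likely explanation (a hypothesis, not verified against the Flower internals): ResNet18 contains BatchNorm layers, and each one stores a 0-dim num_batches_tracked tensor in its state_dict. The legacy torch.Tensor(...) constructor appears to treat a 0-dim integer array as a size rather than as data, so a counter value of 0 turns into an empty tensor, which matches the IndexError raised inside _load_from_state_dict. torch.from_numpy keeps the scalar intact. This would also explain why adding summary(...) hides the crash: the forward pass torchsummary runs increments the BatchNorm counters, so they are no longer zero. A minimal check:

import numpy as np
import torch

# BatchNorm's num_batches_tracked is a 0-dim (scalar) int64 tensor; it arrives
# on the server as a 0-dim NumPy array.
counter = np.array(0, dtype=np.int64)

print(torch.from_numpy(np.copy(counter)).shape)  # torch.Size([]) -> proper scalar
print(torch.Tensor(counter).shape)               # likely torch.Size([0]) -> empty tensor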