DeepSpeedExamples
ZeRO 3 not working with GAN example: IndexError: tuple index out of range
Using the GAN example with the following DeepSpeed config for ZeRO stage 3 with CPU offload, I get the error shown below:
{
  "train_batch_size": 64,
  "zero_optimization": {
    "stage": 3,
    "contiguous_gradients": true,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_prefetch_bucket_size": 1e7,
    "stage3_param_persistence_threshold": 1e5,
    "reduce_bucket_size": 1e7,
    "sub_group_size": 1e9,
    "offload_optimizer": {
      "device": "cpu"
    },
    "offload_param": {
      "device": "cpu"
    }
  }
}
Loading extension module utils...
Time to load utils op: 0.00021600723266601562 seconds
[2022-08-02 07:09:09,796] [INFO] [stage3.py:1834:_overflow_clean_up] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 4294967296, reducing to 4294967296
[2022-08-02 07:09:09,797] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 time (ms) | backward_microstep: 261.77 | backward_inner_microstep: 249.30 | backward_allreduce_microstep: 12.28 | step_microstep: 1.40
[2022-08-02 07:09:09,797] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 time (ms) | backward: 261.77 | backward_inner: 249.32 | backward_allreduce: 12.28 | step: 1.40
[2022-08-02 07:09:09,841] [INFO] [stage3.py:1834:_overflow_clean_up] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 4294967296, reducing to 4294967296
[2022-08-02 07:09:09,842] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 time (ms) | backward_microstep: 38.73 | backward_inner_microstep: 33.22 | backward_allreduce_microstep: 5.40 | step_microstep: 0.63
[2022-08-02 07:09:09,842] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 time (ms) | backward: 38.73 | backward_inner: 33.22 | backward_allreduce: 5.40 | step: 0.64
[0/1][0/3166] Loss_D: 1.8203 Loss_G: 0.6108 D(x): 0.5308 D(G(z)): 0.5933 / 0.5933
Traceback (most recent call last):
  File "gan_deepspeed_train.py", line 182, in <module>
    main()
  File "gan_deepspeed_train.py", line 179, in main
    train(args)
  File "gan_deepspeed_train.py", line 125, in train
    output = netD(real)
  File "/home/host/.local/bin/anaconda3/envs/pytorchtut/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1137, in _call_impl
    result = hook(self, input)
  File "/home/host/.local/bin/anaconda3/envs/pytorchtut/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 11, in wrapped_fn
    return func(*args, **kwargs)
  File "/home/host/.local/bin/anaconda3/envs/pytorchtut/lib/python3.8/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 324, in _pre_forward_module_hook
    self.pre_sub_module_forward_function(module)
  File "/home/host/.local/bin/anaconda3/envs/pytorchtut/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/host/.local/bin/anaconda3/envs/pytorchtut/lib/python3.8/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 444, in pre_sub_module_forward_function
    param_coordinator.trace_prologue(sub_module)
  File "/home/host/.local/bin/anaconda3/envs/pytorchtut/lib/python3.8/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 150, in trace_prologue
    if sub_module != self.__submodule_order[self.__step_id]:
IndexError: tuple index out of range
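The failing line in trace_prologue indexes a tuple of recorded submodules (`__submodule_order`) with a step counter (`__step_id`), so the IndexError means the counter ran past the end of the recorded order. The toy model below illustrates that failure mode only. `ToyCoordinator`, `finish_trace`, and the module names are hypothetical and are not DeepSpeed's code. It assumes one pass records the module order and later passes replay against it; running more forward calls than were recorded, before the counter is rewound, reproduces the same error message.

```python
class ToyCoordinator:
    """Hypothetical, simplified model of the ordering check in trace_prologue.

    Not DeepSpeed's implementation: it only mimics the pattern
    `sub_module != submodule_order[step_id]` seen in the traceback.
    """

    def __init__(self):
        self.submodule_order = ()  # recorded module sequence (a tuple)
        self.step_id = 0           # position within the recorded sequence
        self.recording = True

    def trace_prologue(self, sub_module):
        if self.recording:
            # Recording pass: append each module in the order it runs.
            self.submodule_order += (sub_module,)
        elif sub_module != self.submodule_order[self.step_id]:
            # Indexing past the end raises IndexError before this check.
            raise RuntimeError(f"unexpected module {sub_module!r}")
        self.step_id += 1

    def finish_trace(self):
        # End of the recorded pass: freeze the order and rewind the counter.
        self.recording = False
        self.step_id = 0


coord = ToyCoordinator()
for name in ("conv1", "conv2"):  # one recorded forward pass
    coord.trace_prologue(name)
coord.finish_trace()

try:
    # Replay three module calls against a two-entry trace without rewinding.
    for name in ("conv1", "conv2", "conv1"):
        coord.trace_prologue(name)
except IndexError as exc:
    print(exc)  # tuple index out of range
```

This does not establish why the counter overruns in the GAN script; it only shows what condition the error message corresponds to.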
Pip freeze
absl-py==1.2.0
aiohttp==3.8.1
aiosignal==1.2.0
antlr4-python3-runtime==4.9.3
appdirs==1.4.4
asttokens==2.0.5
async-timeout==4.0.2
attrs==22.1.0
backcall==0.2.0
black==19.10b0
brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1648854175163/work
cachetools==5.2.0
certifi @ file:///opt/conda/conda-bld/certifi_1655968806487/work/certifi
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1625835293160/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1655906222726/work
click @ file:///tmp/build/80754af9/click_1646038465422/work
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1652967113783/work
cycler==0.11.0
decorator==5.1.1
deepspeed==0.6.7
einops==0.4.1
executing==0.9.1
fairscale==0.4.8
fonttools==4.34.4
frozenlist==1.3.0
fsspec==2022.7.1
ftfy==6.1.1
google-auth==2.9.1
google-auth-oauthlib==0.4.6
grpcio==1.48.0
hjson==3.0.2
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1642433548627/work
imageio==2.20.0
importlib-metadata==4.12.0
ipython==8.4.0
jedi==0.18.1
kiwisolver==1.4.4
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.5.2
matplotlib-inline==0.1.3
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
mpi4py==3.1.3
multidict==6.0.2
mypy-extensions==0.4.3
ninja==1.10.2.3
numpy @ file:///opt/conda/conda-bld/numpy_and_numpy_base_1652801679809/work
oauthlib==3.2.0
olefile @ file:///home/conda/feedstock_root/build_artifacts/olefile_1602866521163/work
omegaconf==2.2.2
packaging==21.3
parso==0.8.3
pathspec==0.7.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.2.0
prompt-toolkit==3.0.30
protobuf==3.19.4
psutil==5.9.1
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==8.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic==1.9.1
pyDeprecate==0.3.2
Pygments==2.12.0
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1643496850550/work
pyparsing==3.0.9
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1648857275402/work
python-dateutil==2.8.2
pytorch-lightning==1.6.5
PyYAML==6.0
regex==2022.7.25
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1656534056640/work
requests-oauthlib==1.3.1
rsa==4.9
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
stack-data==0.3.0
tensorboard==2.9.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
toml @ file:///tmp/build/80754af9/toml_1616166611790/work
torch==1.12.0
torchaudio==0.12.0
torchmetrics==0.9.3
torchvision==0.13.0
tqdm==4.64.0
traitlets==5.3.0
typed-ast @ file:///tmp/build/80754af9/typed-ast_1624953673417/work
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1656706066251/work
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1658789158161/work
wcwidth==0.2.5
Werkzeug==2.2.1
yarl==1.7.2
zipp==3.8.1
With the default configuration in the repo, the example runs fine. Environment: Linux, NVIDIA 3060 (12 GB).
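Since the default config works and the ZeRO 3 offload config does not, listing exactly which settings differ between the two may help narrow down the trigger. A minimal stdlib-only sketch: `flatten`, `diff_configs`, and `BASELINE` are hypothetical names, and `BASELINE` is a placeholder to be replaced with the repo's actual default GAN config.

```python
# The ZeRO 3 offload config from this report, as a Python dict.
ZERO3_OFFLOAD = {
    "train_batch_size": 64,
    "zero_optimization": {
        "stage": 3,
        "contiguous_gradients": True,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_prefetch_bucket_size": 1e7,
        "stage3_param_persistence_threshold": 1e5,
        "reduce_bucket_size": 1e7,
        "sub_group_size": 1e9,
        "offload_optimizer": {"device": "cpu"},
        "offload_param": {"device": "cpu"},
    },
}

# Hypothetical placeholder: replace with the repo's default GAN config.
BASELINE = {"train_batch_size": 64}


def flatten(cfg, prefix=""):
    """Flatten nested dicts into {"a.b.c": leaf_value} pairs."""
    out = {}
    for key, value in cfg.items():
        path = prefix + key
        if isinstance(value, dict):
            out.update(flatten(value, path + "."))
        else:
            out[path] = value
    return out


def diff_configs(a, b):
    """Map each differing leaf path to (value_in_a, value_in_b); None means absent."""
    fa, fb = flatten(a), flatten(b)
    return {
        path: (fa.get(path), fb.get(path))
        for path in sorted(fa.keys() | fb.keys())
        if fa.get(path) != fb.get(path)
    }


for path, (old, new) in diff_configs(BASELINE, ZERO3_OFFLOAD).items():
    print(f"{path}: {old!r} -> {new!r}")
```

Changing one differing setting at a time (for example, removing `offload_param`) and re-running could then show whether the failure follows parameter offload or stage 3 itself.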