AttributeError: _old_init is not found in the LitEma, please make sure that you have imported LitEma before entering the context.

Open guohe369 opened this issue 2 years ago • 1 comments

🐛 Describe the bug

When I use the ColossalAI strategy to train Stable Diffusion v2, I get the following error:

please init model in the ColoInitContext

    File "/data/guohe/stable-diffusion-v2/scripts/train/main.py", line 772, in
        trainer.fit(model, data)
    AssertionError: please init model in the ColoInitContext

After I added a `with ColoInitContext(...)` line before the model instantiation,

    with ColoInitContext(device=torch.device('cuda'), dtype=torch.half):
        model = instantiate_from_config(config.model)

another error occurred:

_old_init is not found in the LitEma, please make sure that you have imported LitEma before entering the context.

    File "/data/guohe/stable-diffusion-v2/scripts/train/main.py", line 551, in
        model = instantiate_from_config(config.model)
    AttributeError: _old_init is not found in the LitEma, please make sure that you have imported LitEma before entering the context.
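The wording of this message suggests that the context records the original `__init__` of module classes that already exist when it is entered, and then fails on exit for any class that was defined (imported) in between. The following is a toy model of that behavior in plain Python, not ColossalAI's actual code: `Module`, `ToyInitContext`, `Early`, and `Late` are hypothetical stand-ins, and the real implementation may differ in its details.

```python
class Module:
    """Stand-in for torch.nn.Module (hypothetical, for illustration only)."""

    def __init__(self):
        self.initialized = True


def all_subclasses(cls):
    """Return every direct and indirect subclass of cls known right now."""
    found = []
    for sub in cls.__subclasses__():
        found.append(sub)
        found.extend(all_subclasses(sub))
    return found


class ToyInitContext:
    """Toy context that wraps __init__ of the classes that exist on entry."""

    def __enter__(self):
        for cls in all_subclasses(Module):
            cls._old_init = cls.__init__  # remember the original constructor
            cls.__init__ = self._wrap(cls.__init__)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        for cls in all_subclasses(Module):
            # A class created inside the context was never patched on entry,
            # so there is nothing to restore for it.
            if "_old_init" not in vars(cls):
                raise AttributeError(f"_old_init is not found in the {cls.__name__}")
            cls.__init__ = cls._old_init
            del cls._old_init
        return False

    @staticmethod
    def _wrap(init):
        def wrapper(self, *args, **kwargs):
            init(self, *args, **kwargs)
            self.post_init_ran = True  # stand-in for a post-init hook

        return wrapper


class Early(Module):  # defined before the context, like an eager import
    pass


def demo():
    try:
        with ToyInitContext():

            class Late(Module):  # defined inside the context, like a lazy import
                pass

            Late()
    except AttributeError as err:
        return str(err)
    return "no error"


message = demo()
print(message)
```

In this toy model, `Early` is patched and restored cleanly, while `Late` triggers the error on exit. If ColossalAI behaves similarly, `instantiate_from_config` would hit the same problem because it imports classes lazily from the config while the context is already active.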

After I imported LitEma in main.py, the same kind of error then pushed me to import SILU...
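Rather than adding imports one class at a time, one possible workaround is to import every module named in the config before entering the context. The sketch below assumes the config follows the `target: package.module.ClassName` convention that `instantiate_from_config` reads; `collect_targets` and `preimport_targets` are hypothetical helper names, and the example config uses standard-library targets only so that the sketch runs anywhere.

```python
import importlib


def collect_targets(config):
    """Recursively gather every string stored under a "target" key."""
    targets = []
    if isinstance(config, dict):
        for key, value in config.items():
            if key == "target" and isinstance(value, str):
                targets.append(value)
            else:
                targets.extend(collect_targets(value))
    elif isinstance(config, (list, tuple)):
        for item in config:
            targets.extend(collect_targets(item))
    return targets


def preimport_targets(config):
    """Import the module part of each "package.module.ClassName" target."""
    imported = []
    for target in collect_targets(config):
        module_name, _, _class_name = target.rpartition(".")
        if module_name:
            importlib.import_module(module_name)
            imported.append(module_name)
    return imported


# Stand-in config; a real one would name the model's own modules instead.
example_config = {
    "target": "collections.OrderedDict",
    "params": {
        "first_stage_config": {"target": "json.decoder.JSONDecoder"},
        "cond_stage_config": {"target": "fractions.Fraction"},
    },
}
imported_modules = preimport_targets(example_config)
print(imported_modules)
```

In main.py this would presumably be called on a plain-dict copy of the config, for example `preimport_targets(OmegaConf.to_container(config.model, resolve=True))`, just before the `with ColoInitContext(...)` block. Importing a target module also imports whatever that module imports at the top level, which may cover classes such as LitEma, but classes defined in modules that are only imported later inside a function would still need an explicit import.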

Please help me, thanks!

Environment

_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
absl-py 1.4.0 pypi_0 pypi
aiohttp 3.8.4 pypi_0 pypi
aiosignal 1.3.1 pypi_0 pypi
albumentations 0.4.3 pypi_0 pypi
altair 4.2.2 pypi_0 pypi
antlr4-python3-runtime 4.8 pypi_0 pypi
async-timeout 4.0.2 pypi_0 pypi
attrs 22.2.0 pypi_0 pypi
autopep8 1.6.0 pyhd3eb1b0_1
awscli 1.27.84 pypi_0 pypi
backports-zoneinfo 0.2.1 pypi_0 pypi
bcrypt 4.0.1 pypi_0 pypi
blas 1.0 mkl
blessings 1.7 py38h06a4308_1002
blinker 1.5 pypi_0 pypi
botocore 1.29.84 pypi_0 pypi
brotlipy 0.7.0 py38h27cfd23_1003
bzip2 1.0.8 h7b6447c_0
ca-certificates 2023.01.10 h06a4308_0
cachetools 5.3.0 pypi_0 pypi
certifi 2022.12.7 py38h06a4308_0
cffi 1.15.1 py38h5eee18b_3
cfgv 3.3.1 pypi_0 pypi
charset-normalizer 2.0.4 pyhd3eb1b0_0
click 8.1.3 pypi_0 pypi
clip 1.0 dev_0
cmake 3.25.2 pypi_0 pypi
colorama 0.4.4 pypi_0 pypi
coloredlogs 15.0.1 pypi_0 pypi
colossalai 0.2.5 pypi_0 pypi
contexttimer 0.3.3 pypi_0 pypi
cryptography 39.0.1 py38h9ce1e76_0
cudatoolkit 11.3.1 h2bc3f7f_2
decorator 5.1.1 pypi_0 pypi
diffusers 0.14.0 pypi_0 pypi
distlib 0.3.6 pypi_0 pypi
docutils 0.16 pypi_0 pypi
einops 0.3.0 pypi_0 pypi
entrypoints 0.4 pypi_0 pypi
fabric 3.0.0 pypi_0 pypi
ffmpeg 4.3 hf484d3e_0 pytorch
filelock 3.9.0 pypi_0 pypi
fire 0.5.0 pypi_0 pypi
flatbuffers 23.3.3 pypi_0 pypi
flit-core 3.6.0 pyhd3eb1b0_0
freetype 2.12.1 h4a9f257_0
frozenlist 1.3.3 pypi_0 pypi
fsspec 2023.1.0 pypi_0 pypi
ftfy 6.1.1 pypi_0 pypi
future 0.18.3 py38h06a4308_0
giflib 5.2.1 h5eee18b_3
gitdb 4.0.10 pypi_0 pypi
gitpython 3.1.31 pypi_0 pypi
gmp 6.2.1 h295c915_3
gnutls 3.6.15 he1e5248_0
google-auth 2.16.2 pypi_0 pypi
google-auth-oauthlib 0.4.6 pypi_0 pypi
gpustat 0.6.0 pyhd3eb1b0_1
grpcio 1.51.3 pypi_0 pypi
huggingface-hub 0.12.1 pypi_0 pypi
humanfriendly 10.0 pypi_0 pypi
identify 2.5.18 pypi_0 pypi
idna 3.4 py38h06a4308_0
imageio 2.26.0 pypi_0 pypi
imageio-ffmpeg 0.4.2 pypi_0 pypi
imgaug 0.2.6 pypi_0 pypi
importlib-metadata 6.0.0 pypi_0 pypi
importlib-resources 5.12.0 pypi_0 pypi
intel-extension-for-pytorch 1.13.100 pypi_0 pypi
intel-openmp 2023.0.0 pypi_0 pypi
invisible-watermark 0.1.5 pypi_0 pypi
invoke 2.0.0 pypi_0 pypi
jinja2 3.1.2 pypi_0 pypi
jmespath 1.0.1 pypi_0 pypi
jpeg 9e h5eee18b_1
jsonschema 4.17.3 pypi_0 pypi
lame 3.100 h7b6447c_0
lazy-loader 0.1 pypi_0 pypi
lcms2 2.12 h3be6417_0
ld_impl_linux-64 2.38 h1181459_1
lerc 3.0 h295c915_0
libdeflate 1.17 h5eee18b_0
libffi 3.4.2 h6a678d5_6
libgcc-ng 11.2.0 h1234567_1
libgfortran-ng 11.2.0 h00389a5_1
libgfortran5 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libiconv 1.16 h7f8727e_2
libidn2 2.3.2 h7f8727e_0
libpng 1.6.39 h5eee18b_0
libprotobuf 3.20.3 he621ea3_0
libstdcxx-ng 11.2.0 h1234567_1
libtasn1 4.16.0 h27cfd23_0
libtiff 4.5.0 h6a678d5_2
libunistring 0.9.10 h27cfd23_0
libuv 1.44.2 h5eee18b_0
libwebp 1.2.4 h11a3e52_1
libwebp-base 1.2.4 h5eee18b_1
lightning-utilities 0.7.1 pypi_0 pypi
lit 16.0.0rc3 pypi_0 pypi
lz4-c 1.9.4 h6a678d5_0
markdown 3.4.1 pypi_0 pypi
markdown-it-py 2.2.0 pypi_0 pypi
markupsafe 2.1.2 pypi_0 pypi
mdurl 0.1.2 pypi_0 pypi
mkl 2021.4.0 h06a4308_640
mkl-service 2.4.0 py38h7f8727e_0
mkl_fft 1.3.1 py38hd3c417c_0
mkl_random 1.2.2 py38h51133e4_0
mpmath 1.2.1 pypi_0 pypi
multidict 6.0.4 pypi_0 pypi
mypy-extensions 1.0.0 pypi_0 pypi
ncurses 6.4 h6a678d5_0
nettle 3.7.3 hbbd107a_1
networkx 3.0 pypi_0 pypi
ninja 1.11.1 pypi_0 pypi
ninja-base 1.10.2 hd09550d_5
nodeenv 1.7.0 pypi_0 pypi
numpy 1.24.2 pypi_0 pypi
numpy-base 1.23.5 py38h31eccc5_0
nvidia-cublas-cu11 11.10.3.66 pypi_0 pypi
nvidia-cuda-nvrtc-cu11 11.7.99 pypi_0 pypi
nvidia-cuda-runtime-cu11 11.7.99 pypi_0 pypi
nvidia-cudnn-cu11 8.5.0.96 pypi_0 pypi
nvidia-ml 7.352.0 pyhd3eb1b0_0
oauthlib 3.2.2 pypi_0 pypi
omegaconf 2.1.1 pypi_0 pypi
onnx 1.13.1 pypi_0 pypi
onnxruntime 1.14.1 pypi_0 pypi
open-clip-torch 2.15.0 pypi_0 pypi
opencv-python 4.1.2.30 pypi_0 pypi
opencv-python-headless 4.7.0.72 pypi_0 pypi
openh264 2.1.1 h4ff587b_0
openssl 1.1.1t h7f8727e_0
packaging 23.0 pypi_0 pypi
pandas 1.5.3 pypi_0 pypi
paramiko 3.0.0 pypi_0 pypi
pillow 6.2.2 pypi_0 pypi
pip 20.3.3 py38h06a4308_0
pkgutil-resolve-name 1.3.10 pypi_0 pypi
platformdirs 3.1.0 pypi_0 pypi
pre-commit 3.1.1 pypi_0 pypi
protobuf 3.20.3 pypi_0 pypi
psutil 5.9.0 py38h5eee18b_0
pudb 2019.2 pypi_0 pypi
pyarrow 11.0.0 pypi_0 pypi
pyasn1 0.4.8 pypi_0 pypi
pyasn1-modules 0.2.8 pypi_0 pypi
pycodestyle 2.10.0 py38h06a4308_0
pycparser 2.21 pyhd3eb1b0_0
pydeck 0.8.0 pypi_0 pypi
pydeprecate 0.3.1 pypi_0 pypi
pygments 2.14.0 pypi_0 pypi
pympler 1.0.1 pypi_0 pypi
pynacl 1.5.0 pypi_0 pypi
pyopenssl 23.0.0 py38h06a4308_0
pyre-extensions 0.0.23 pypi_0 pypi
pyrsistent 0.19.3 pypi_0 pypi
pysocks 1.7.1 py38h06a4308_0
python 3.8.16 h7a1cb2a_3
python-dateutil 2.8.2 pypi_0 pypi
pytorch-lightning 1.9.2 pypi_0 pypi
pytorch-mutex 1.0 cuda pytorch
pytz 2022.7.1 pypi_0 pypi
pytz-deprecation-shim 0.1.0.post0 pypi_0 pypi
pywavelets 1.4.1 pypi_0 pypi
pyyaml 5.4.1 pypi_0 pypi
readline 8.2 h5eee18b_0
regex 2022.10.31 pypi_0 pypi
requests 2.28.1 py38h06a4308_0
requests-oauthlib 1.3.1 pypi_0 pypi
rich 13.3.1 pypi_0 pypi
rsa 4.7.2 pypi_0 pypi
s3transfer 0.6.0 pypi_0 pypi
scikit-image 0.20.0 pypi_0 pypi
scipy 1.9.1 pypi_0 pypi
semver 2.13.0 pypi_0 pypi
sentencepiece 0.1.97 pypi_0 pypi
setuptools 65.6.3 py38h06a4308_0
six 1.16.0 pyhd3eb1b0_1
smmap 5.0.0 pypi_0 pypi
sqlite 3.40.1 h5082296_0
stable-diffusion 0.0.1 dev_0
streamlit 1.19.0 pypi_0 pypi
sympy 1.11.1 pypi_0 pypi
taming-transformers 0.0.1 dev_0
tensorboard 2.12.0 pypi_0 pypi
tensorboard-data-server 0.7.0 pypi_0 pypi
tensorboard-plugin-wit 1.8.1 pypi_0 pypi
tensorboardx 2.6 pypi_0 pypi
termcolor 2.2.0 pypi_0 pypi
test-tube 0.7.5 pypi_0 pypi
tifffile 2023.2.28 pypi_0 pypi
timm 0.6.12 pypi_0 pypi
tk 8.6.12 h1ccaba5_0
tokenizers 0.12.1 pypi_0 pypi
toml 0.10.2 pyhd3eb1b0_0
toolz 0.12.0 pypi_0 pypi
torch 1.10.0+cu113 pypi_0 pypi
torch-fidelity 0.3.0 pypi_0 pypi
torchaudio 0.12.0+cu113 pypi_0 pypi
torchmetrics 0.11.3 pypi_0 pypi
torchvision 0.13.0+cu113 pypi_0 pypi
tornado 6.2 pypi_0 pypi
tqdm 4.65.0 pypi_0 pypi
transformers 4.19.2 pypi_0 pypi
triton 2.0.0 pypi_0 pypi
typing-extensions 4.4.0 py38h06a4308_0
typing-inspect 0.8.0 pypi_0 pypi
typing_extensions 4.4.0 py38h06a4308_0
tzdata 2022.7 pypi_0 pypi
tzlocal 4.2 pypi_0 pypi
urllib3 1.26.14 py38h06a4308_0
urwid 2.1.2 pypi_0 pypi
validators 0.20.0 pypi_0 pypi
virtualenv 20.20.0 pypi_0 pypi
watchdog 2.3.1 pypi_0 pypi
wcwidth 0.2.6 pypi_0 pypi
werkzeug 2.2.3 pypi_0 pypi
wheel 0.38.4 py38h06a4308_0
xformers 0.0.17+b89a493.d20230305 dev_0
xz 5.2.10 h5eee18b_1
yaml 0.2.5 h7b6447c_0
yarl 1.8.2 pypi_0 pypi
zipp 3.15.0 pypi_0 pypi
zlib 1.2.13 h5eee18b_0
zstd 1.5.2 ha4553b6_0

guohe369, Mar 08 '23 11:03

I have initialized all modules in configure_sharded_model:

    def configure_sharded_model(self):
        if self.use_colossalai:
            rank_zero_info("Configure sharded model for LatentDiffusion")
            self.model = DiffusionWrapper(self.unet_config, self.conditioning_key)
            count_params(self.model, verbose=True)
            if self.use_ema:
                self.model_ema = LitEma(self.model)

            if self.ckpt is not None:
                self.init_from_ckpt(
                    self.ckpt, ignore_keys=self.ignore_keys, only_model=self.load_only_unet)
                if self.reset_ema:
                    assert self.use_ema
                    rank_zero_info(
                        f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
                    self.model_ema = LitEma(self.model)

            self.register_schedule(given_betas=self.given_betas,
                                   beta_schedule=self.beta_schedule,
                                   timesteps=self.timesteps,
                                   linear_start=self.linear_start,
                                   linear_end=self.linear_end,
                                   cosine_s=self.cosine_s)

            self.logvar = torch.full(
                fill_value=self.logvar_init, size=(self.num_timesteps,))
            if self.learn_logvar:
                self.logvar = nn.Parameter(self.logvar, requires_grad=True)
            if self.ucg_training:
                self.ucg_prng = np.random.RandomState()
            self.learnable_vector = nn.Parameter(torch.randn((1,1,1024)), requires_grad=True)
            self.instantiate_first_stage(self.first_stage_config)
            self.instantiate_cond_stage(self.cond_stage_config)
            if self.ckpt is not None:
                self.init_from_ckpt(self.ckpt, self.ignore_keys)
                self.restarted_from_ckpt = True
                if self.reset_ema:
                    assert self.use_ema
                    rank_zero_info(
                        f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
                    self.model_ema = LitEma(self.model)
        else:
            return super().configure_sharded_model()

guohe369, Mar 08 '23 11:03

Hi @guohe369, we have updated the code a lot since this was reported. Please check the latest version. This issue was closed due to inactivity. Thanks.

binmakeswell, Apr 27 '23 10:04