
bash run_vision_chat.sh causes flax.errors.ScopeParamNotFoundError: Could not find parameter named "embedding" in scope "/transformer/wte"

Open jackyin68 opened this issue 1 year ago • 6 comments

Running "bash scripts/run_vision_chat.sh" fails with the error below. How can I fix it?

(lwm) llm@llm-PowerEdge-R730xd:~/projects/LWM-main$ bash scripts/run_vision_chat.sh
I0221 14:02:43.257625 139932541391232 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0221 14:02:43.260045 139932541391232 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
100%|██████████| 1/1 [00:05<00:00, 5.59s/it]
Traceback (most recent call last):
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 254, in <module>
    run(main)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 250, in main
    output = sampler(prompts, FLAGS.max_n_frames)[0]
  File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 230, in __call__
    output, self.sharded_rng = self._forward_generate(
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 781, in infer_params
    return common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 493, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 996, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
    ans = call(fun, *args)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 936, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 206, in fn
    output = self.model.generate(
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 429, in generate
    return self._sample(
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 733, in _sample
    state = sample_search_body_fn(state)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 704, in sample_search_body_fn
    model_outputs = model(state.running_token, params=params, **state.model_kwargs)
  File "/home/llm/projects/LWM-main/lwm/vision_llama.py", line 232, in __call__
    outputs = self.module.apply(
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 1511, in apply
    return apply(
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/core/scope.py", line 934, in wrapper
    y = fn(root, *args, **kwargs)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 2082, in scope_fn
    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 854, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/llm/projects/LWM-main/lwm/vision_llama.py", line 401, in __call__
    outputs = self.transformer(
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 854, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/llm/projects/LWM-main/lwm/vision_llama.py", line 313, in __call__
    input_text_embeds = self.wte(jnp.where(vision_masks, 0, input_ids))
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 836, in _call_wrapped_method
    self._try_setup()
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 1094, in _try_setup
    self.setup()
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 854, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/linear.py", line 771, in setup
    self.embedding = self.param('embedding',
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 1263, in param
    v = self.scope.param(name, init_fn, *init_args, unbox=unbox)
  File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/core/scope.py", line 842, in param
    raise errors.ScopeParamNotFoundError(name, self.path_text)
flax.errors.ScopeParamNotFoundError: Could not find parameter named "embedding" in scope "/transformer/wte". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamNotFoundError)

jackyin68 avatar Feb 21 '24 06:02 jackyin68
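
The last line of the traceback says that Flax looked for a parameter named "embedding" under the scope "/transformer/wte" and did not find it in the parameters handed to apply. Below is a minimal sketch of that kind of lookup in plain Python, assuming the loaded parameters can be viewed as a nested dict. `has_param` and the two toy parameter trees are hypothetical, made up for illustration; they are not part of LWM or Flax.

```python
def has_param(params, scope, name):
    """Return True if the nested dict `params` holds `name` under `scope`.

    `scope` is a slash-separated path such as "/transformer/wte",
    in the same form the Flax error message prints it.
    """
    node = params
    for key in scope.strip("/").split("/"):
        if not isinstance(node, dict) or key not in node:
            return False
        node = node[key]
    return isinstance(node, dict) and name in node


# Toy parameter trees (hypothetical data): one complete, one missing "wte".
complete = {"transformer": {"wte": {"embedding": [[0.0, 0.1]]}}}
incomplete = {"transformer": {"h": {}}}

print(has_param(complete, "/transformer/wte", "embedding"))
print(has_param(incomplete, "/transformer/wte", "embedding"))
```

If a check like this fails on the real parameter tree, the parameters that were loaded do not contain the embedding table the model expects.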

Who can give me some advice?

jackyin68 avatar Feb 21 '24 11:02 jackyin68

Can you paste your run_vision_chat.sh script, as well as your jax/flax versions?

wilson1yan avatar Feb 21 '24 20:02 wilson1yan
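
One possible way to collect the requested versions without pasting a full package list is the standard-library `importlib.metadata` module. This is only a sketch: `installed_versions` is a hypothetical helper name, and the package list is an assumption about what is relevant here.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Print the versions of the packages asked about in this thread.
for name, version in installed_versions(["jax", "jaxlib", "flax"]).items():
    print(f"{name}: {version or 'not installed'}")
```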

Thank you in advance. The requested information is below.

run_vision_chat.sh

#! /bin/bash

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd $PROJECT_DIR
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"

export llama_tokenizer_path="LWM-Chat-1M-Jax/tokenizer.model"
export vqgan_checkpoint="LWM-Chat-1M-Jax/vqgan"
export lwm_checkpoint="LWM-Chat-1M-Jax/params"
export input_file="demo.jpg"

python3 -u -m lwm.vision_chat \
    --prompt="What is the image about?" \
    --input_file="$input_file" \
    --vqgan_checkpoint="$vqgan_checkpoint" \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --max_n_frames=8 \
    --update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,remat_attention='',scan_mlp=False,scan_mlp_chunk_size=2048,remat_mlp='',remat_block='',scan_layers=True)" \
    --load_checkpoint="params::$lwm_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path" \
    2>&1 | tee ~/output.log
read

pip list

Package Version
absl-py 2.1.0
aiohttp 3.9.3
aiosignal 1.3.1
appdirs 1.4.4
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
build 1.0.3
cachetools 5.3.2
certifi 2024.2.2
charset-normalizer 3.3.2
chex 0.1.82
click 8.1.7
cloudpickle 3.0.0
contextlib2 21.6.0
datasets 2.13.0
decorator 5.1.1
decord 0.6.0
dill 0.3.6
docker-pycreds 0.4.0
einops 0.7.0
etils 1.7.0
exceptiongroup 1.2.0
executing 2.0.1
filelock 3.13.1
flax 0.7.0
frozenlist 1.4.1
fsspec 2024.2.0
gcsfs 2024.2.0
gitdb 4.0.11
GitPython 3.1.42
google-api-core 2.17.1
google-auth 2.28.0
google-auth-oauthlib 1.2.0
google-cloud-core 2.4.1
google-cloud-storage 2.14.0
google-crc32c 1.5.0
google-resumable-media 2.7.0
googleapis-common-protos 1.62.0
huggingface-hub 0.20.3
idna 3.6
imageio 2.34.0
imageio-ffmpeg 0.4.9
importlib-resources 6.1.1
ipdb 0.13.13
ipython 8.21.0
jax 0.4.23
jaxlib 0.4.23+cuda12.cudnn89
jedi 0.19.1
markdown-it-py 3.0.0
matplotlib-inline 0.1.6
mdurl 0.1.2
ml-collections 0.1.1
ml-dtypes 0.3.2
msgpack 1.0.7
multidict 6.0.5
multiprocess 0.70.14
nest-asyncio 1.6.0
numpy 1.26.4
nvidia-cublas-cu12 12.3.4.1
nvidia-cuda-cupti-cu12 12.3.101
nvidia-cuda-nvcc-cu12 12.3.107
nvidia-cuda-nvrtc-cu12 12.3.107
nvidia-cuda-runtime-cu12 12.3.101
nvidia-cudnn-cu12 8.9.7.29
nvidia-cufft-cu12 11.0.12.1
nvidia-cusolver-cu12 11.5.4.101
nvidia-cusparse-cu12 12.2.0.103
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.3.101
oauthlib 3.2.2
opt-einsum 3.3.0
optax 0.1.7
orbax-checkpoint 0.5.3
packaging 23.2
pandas 2.2.0
parso 0.8.3
pexpect 4.9.0
pillow 10.2.0
pip 23.3.1
prompt-toolkit 3.0.43
protobuf 4.25.3
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 15.0.0
pyasn1 0.5.1
pyasn1-modules 0.3.0
Pygments 2.17.2
pyproject_hooks 1.0.0
python-dateutil 2.8.2
pytz 2024.1
PyYAML 6.0.1
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
rich 13.7.0
rsa 4.9
scipy 1.12.0
sentencepiece 0.2.0
sentry-sdk 1.40.5
setproctitle 1.3.3
setuptools 68.2.2
six 1.16.0
smmap 5.0.1
stack-data 0.6.3
tensorstore 0.1.53
tiktoken 0.6.0
tokenizers 0.13.3
tomli 2.0.1
toolz 0.12.1
tqdm 4.66.2
traitlets 5.14.1
transformers 4.29.2
tux 0.0.2
typing_extensions 4.9.0
tzdata 2024.1
urllib3 2.2.1
wandb 0.16.3
wcwidth 0.2.13
wheel 0.41.2
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0

jackyin68 avatar Feb 22 '24 08:02 jackyin68

I encountered the same problem, but eventually found that the cause was the incomplete download of the model file.

hxmmxh avatar Feb 23 '24 06:02 hxmmxh
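
An incomplete download can often be spotted by comparing the local file's size and checksum with the values published where the model is hosted. Below is a minimal sketch using only the Python standard library. `file_fingerprint` and `matches_expected` are hypothetical helper names, and the expected size and SHA-256 are assumptions that have to be read off the download page; they are not given in this thread.

```python
import hashlib
import os


def file_fingerprint(path, chunk_size=1 << 20):
    """Return (size_in_bytes, sha256_hex) for the file at `path`."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in chunks so large checkpoint files do not have to fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return os.path.getsize(path), digest.hexdigest()


def matches_expected(path, expected_size=None, expected_sha256=None):
    """Compare a local file against whichever expected values are supplied."""
    size, sha256 = file_fingerprint(path)
    if expected_size is not None and size != expected_size:
        return False
    if expected_sha256 is not None and sha256 != expected_sha256.lower():
        return False
    return True
```

For the setup in this thread, the file to check would be the one passed to --load_checkpoint (LWM-Chat-1M-Jax/params in the script above). A size mismatch alone is enough to show the file is truncated; the checksum also catches corruption that leaves the size unchanged.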

Thanks. Which model did you use? And are you now able to run "bash run_vision_chat.sh"?

jackyin68 avatar Feb 23 '24 06:02 jackyin68

> Thanks, Which model did you use? And are you ok to run "bash run_vision_chat.sh"?

Hello, I encountered the same problem and eventually found that the model had not been fully uploaded to the server. Can you check whether the size of your local model file matches the original? With a complete model file the command should run without this error. If you still run into problems, please post the details again for further review. Most likely your model file was not fully downloaded, or was not fully transferred. Apart from that I don't see a problem: your environment is set up correctly, and your jax and flax versions are right.

xiaoxiaoli666 avatar Mar 01 '24 07:03 xiaoxiaoli666