EasyLM
Serving errors: deprecated dependencies and structure error
When I tried to serve LLaMA on a v3-8 TPU as suggested in the example script, I ran into some errors.
Environment
- TPU: v3-8
- Software: tpu-vm-base
Command
```shell
$ git clone https://github.com/young-geng/EasyLM
$ cd EasyLM
$ ./scripts/tpu_vm_setup.sh
$
$ python -m EasyLM.models.llama.llama_train \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --total_steps=500 \
    --log_freq=50 \
    --load_llama_config='1b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --save_model_freq=100 \
    --tokenizer.vocab_file='/path/to/tokenizer.model' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=1e-3 \
    --optimizer.adamw_optimizer.end_lr=1e-4 \
    --optimizer.adamw_optimizer.lr_warmup_steps=10 \
    --optimizer.adamw_optimizer.lr_decay_steps=100 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path='/path/to/dataset.jsonl' \
    --train_dataset.json_dataset.seq_length=1024 \
    --train_dataset.json_dataset.batch_size=64 \
    --train_dataset.json_dataset.tokenizer_processes=1 \
    --checkpointer.save_optimizer_state=True \
    --checkpointer.float_dtype=bf16 \
    --logger.online=False \
    --logger.output_dir="~/ellama_checkpoints/" \
    |& tee $HOME/output1107_wiki.txt
$
$ python -m EasyLM.models.llama.llama_serve \
    --load_llama_config='1b' \
    --load_checkpoint="params::/path/to/streaming_train_state" \
    --tokenizer.vocab_file='/path/to/tokenizer.model' \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --input_length=1024 \
    --seq_length=2048 \
    --lm_server.batch_size=4 \
    --lm_server.port=8888 \
    --lm_server.pre_compile='all'
```
1. Deprecation warning
```
ImportError: cannot import name 'soft_unicode' from 'markupsafe'
ImportError: Pandas requires version '3.0.0' or newer of 'jinja2'
```
Both can be solved by adding two lines to `tpu_requirements.txt`:

```
markupsafe==2.0.1
jinja2~=3.0.0
```
```
DeprecationWarning: concurrency_count has been deprecated. Set the concurrency_limit directly on event listeners e.g. btn.click(fn, ..., concurrency_limit=10) or gr.Interface(concurrency_limit=10). If necessary, the total number of workers can be configured via max_threads in launch().
```
I was able to solve this by deleting `concurrency_count=1` in `serving.py`, line 403. According to the Gradio v4.0.0 changelog, `concurrency_count` was removed and can be replaced with `concurrency_limit`. Since I don't fully understand what it was supposed to do and it was set to 1 by default, I simply removed it.
2. Structure error
However, when I solve deprecation errors above, this error appears:
Error Log
```
I1107 06:16:48.996244 140573565926464 mesh_utils.py:260] Reordering mesh to physical ring order on single-tray TPU v2/v3.
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name user_fn already exists, using user_fn_1
  warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name model_fn already exists, using model_fn_1
  warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name model_fn already exists, using model_fn_2
  warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
Traceback (most recent call last):
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 386, in <module>
    mlxu.run(main)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 382, in main
    server.run()
  File "$HOME/EasyLM/EasyLM/serving.py", line 417, in run
    self.loglikelihood(pre_compile_data, pre_compile_data)
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 208, in loglikelihood
    loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 775, in infer_params
    return common_infer_params(pjit_info_args, *args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 88, in forward_loglikelihood
    logits = hf_model.module.apply(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/flax/linen/module.py", line 1511, in apply
    return apply(
  File "$HOME/.local/lib/python3.8/site-packages/flax/core/scope.py", line 930, in wrapper
    raise errors.ApplyScopeInvalidVariablesStructureError(variables)
jax._src.traceback_util.UnfilteredStackTrace: flax.errors.ApplyScopeInvalidVariablesStructureError: Expect the `variables` (first argument) passed to apply() to be a dict with the structure {"params": ...}, but got a dict with an extra params layer, i.e. {"params": {"params": ... } }. You should instead pass in your dict's ["params"]. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesStructureError)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 386, in <module>
    mlxu.run(main)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 382, in main
    server.run()
  File "$HOME/EasyLM/EasyLM/serving.py", line 417, in run
    self.loglikelihood(pre_compile_data, pre_compile_data)
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 208, in loglikelihood
    loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 88, in forward_loglikelihood
    logits = hf_model.module.apply(
flax.errors.ApplyScopeInvalidVariablesStructureError: Expect the `variables` (first argument) passed to apply() to be a dict with the structure {"params": ...}, but got a dict with an extra params layer, i.e. {"params": {"params": ... } }. You should instead pass in your dict's ["params"]. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesStructureError)
```
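The structure complaint in that error can be illustrated with plain dicts. This is only a sketch of the check the message describes, not Flax's actual code:

```python
# Sketch of the variables structure Flax's apply() accepts vs. rejects.
# check() mimics the structure test from the error message, nothing more.
def check(variables):
    inner = variables["params"]
    # An inner dict whose only key is another "params" collection is the
    # "extra params layer" the error message complains about.
    if isinstance(inner, dict) and set(inner) == {"params"}:
        return "rejected: extra params layer"
    return "accepted"

loaded = {"params": {"params": {"wte": [0.1, 0.2]}}}  # what my load produced
print(check(loaded))            # rejected: extra params layer
print(check(loaded["params"]))  # accepted -- pass the dict's ["params"]
```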
It seems like something went wrong with the `"params"` loading in the `load_trainstate_checkpoint` function in `checkpoint.py`, but I couldn't figure out where. Does anyone know what's wrong?
There was some misunderstanding on my part: I should have used `trainstate_params` instead of `params` as the checkpoint prefix in my case.
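The difference between the two prefixes can be sketched with plain dicts. The `load_as` helper below is hypothetical; EasyLM's real loader in `checkpoint.py` is more involved:

```python
# A trainer checkpoint stores the whole train state, and its "params" field
# itself holds the {"params": ...} tree that apply() expects.
train_state_ckpt = {
    "step": 500,
    "opt_state": {},
    "params": {"params": {"wte": [0.1]}},
}

def load_as(prefix, ckpt):
    if prefix == "params":
        # Wraps the stored field directly -> double {"params": {"params": ...}}
        return {"params": ckpt["params"]}
    if prefix == "trainstate_params":
        # Extracts params out of the train state first -> single wrapper
        return ckpt["params"]

wrong = load_as("params", train_state_ckpt)
right = load_as("trainstate_params", train_state_ckpt)
print(list(wrong["params"]))   # ['params'] -> the structure Flax rejects
print(list(right["params"]))   # ['wte']    -> what llama_serve needs
```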
So, the serving command should look like this:
```shell
$ python -m EasyLM.models.llama.llama_serve \
    --load_llama_config='1b' \
    --load_checkpoint="trainstate_params::/path/to/streaming_train_state" \
    --tokenizer.vocab_file='/path/to/tokenizer.model' \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --input_length=1024 \
    --seq_length=2048 \
    --lm_server.batch_size=4 \
    --lm_server.port=8888 \
    --lm_server.pre_compile='all'
```
Sorry for reopening; I thought it'd be better to keep this issue open until the deprecated dependencies are resolved.