
Serving errors: deprecated dependencies and structure error

Open sjw8793 opened this issue 1 year ago • 2 comments

When I tried to serve LLaMA on a v3-8 TPU as suggested in the example script, I ran into several errors.

Environment

  • TPU: v3-8
  • Software: tpu-vm-base

Command

$ git clone https://github.com/young-geng/EasyLM
$ cd EasyLM
$ ./scripts/tpu_vm_setup.sh
$
$ python -m EasyLM.models.llama.llama_train \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --total_steps=500 \
    --log_freq=50 \
    --load_llama_config='1b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --save_model_freq=100 \
    --tokenizer.vocab_file='/path/to/tokenizer.model' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=1e-3 \
    --optimizer.adamw_optimizer.end_lr=1e-4 \
    --optimizer.adamw_optimizer.lr_warmup_steps=10 \
    --optimizer.adamw_optimizer.lr_decay_steps=100 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path='/path/to/dataset.jsonl' \
    --train_dataset.json_dataset.seq_length=1024 \
    --train_dataset.json_dataset.batch_size=64 \
    --train_dataset.json_dataset.tokenizer_processes=1 \
    --checkpointer.save_optimizer_state=True \
    --checkpointer.float_dtype=bf16 \
    --logger.online=False \
    --logger.output_dir="~/ellama_checkpoints/" \
|& tee $HOME/output1107_wiki.txt 
$ 
$ python -m EasyLM.models.llama.llama_serve \
    --load_llama_config='1b' \
    --load_checkpoint="params::/path/to/streaming_train_state" \
    --tokenizer.vocab_file='/path/to/tokenizer.model' \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --input_length=1024 \
    --seq_length=2048 \
    --lm_server.batch_size=4 \
    --lm_server.port=8888 \
    --lm_server.pre_compile='all'

1. Deprecation warning

ImportError: cannot import name 'soft_unicode' from 'markupsafe'
ImportError: Pandas requires version '3.0.0' or newer of 'jinja2'

These can be solved by adding two lines to tpu_requirements.txt:

markupsafe==2.0.1
jinja2~=3.0.0

DeprecationWarning: concurrency_count has been deprecated. Set the concurrency_limit directly on event listeners e.g. btn.click(fn, ..., concurrency_limit=10) or gr.Interface(concurrency_limit=10). If necessary, the total number of workers can be configured via max_threads in launch().

I was able to solve this by deleting concurrency_count=1 from serving.py, line 403. According to the Gradio 4.0.0 changelog, concurrency_count has been removed and can be replaced with concurrency_limit. Since I don't fully understand what it is supposed to do and it defaults to 1, I simply removed it.
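For reference, the Gradio 4 migration can also be done by moving the limit onto the event listener rather than deleting it. The snippet below is a minimal, hypothetical sketch (the handler and component names are made up, not taken from serving.py), guarded so it is a no-op when Gradio is not installed:

```python
def echo(text):
    # Trivial stand-in for the real serving handler.
    return text

try:
    import gradio as gr

    with gr.Blocks() as demo:
        inp = gr.Textbox()
        out = gr.Textbox()
        btn = gr.Button("Run")
        # Gradio 3.x set the worker count globally: demo.queue(concurrency_count=1)
        # Gradio 4.x removed that kwarg; set concurrency_limit per event listener.
        btn.click(echo, inputs=inp, outputs=out, concurrency_limit=1)
except ImportError:
    # Gradio not installed; the migration shown above is the point of the sketch.
    pass

print(echo("ok"))  # → ok
```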

2. Structure error

However, after fixing the deprecation errors above, this error appears:

Error Log

I1107 06:16:48.996244 140573565926464 mesh_utils.py:260] Reordering mesh to physical ring order on single-tray TPU v2/v3.
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name user_fn already exists, using user_fn_1
  warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name model_fn already exists, using model_fn_1
  warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name model_fn already exists, using model_fn_2
  warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
Traceback (most recent call last):
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 386, in <module>
    mlxu.run(main)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 382, in main
    server.run()
  File "$HOME/EasyLM/EasyLM/serving.py", line 417, in run
    self.loglikelihood(pre_compile_data, pre_compile_data)
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 208, in loglikelihood
    loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 775, in infer_params
    return common_infer_params(pjit_info_args, *args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 88, in forward_loglikelihood
    logits = hf_model.module.apply(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/flax/linen/module.py", line 1511, in apply
    return apply(
  File "$HOME/.local/lib/python3.8/site-packages/flax/core/scope.py", line 930, in wrapper
    raise errors.ApplyScopeInvalidVariablesStructureError(variables)
jax._src.traceback_util.UnfilteredStackTrace: flax.errors.ApplyScopeInvalidVariablesStructureError: Expect the `variables` (first argument) passed to apply() to be a dict with the structure {"params": ...}, but got a dict with an extra params layer, i.e.  {"params": {"params": ... } }. You should instead pass in your dict's ["params"]. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesStructureError)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 386, in <module>
    mlxu.run(main)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 382, in main
    server.run()
  File "$HOME/EasyLM/EasyLM/serving.py", line 417, in run
    self.loglikelihood(pre_compile_data, pre_compile_data)
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 208, in loglikelihood
    loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 88, in forward_loglikelihood
    logits = hf_model.module.apply(

flax.errors.ApplyScopeInvalidVariablesStructureError: Expect the variables (first argument) passed to apply() to be a dict with the structure {"params": ...}, but got a dict with an extra params layer, i.e. {"params": {"params": ... } }. You should instead pass in your dict's ["params"]. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesStructureError)

It seems like something went wrong with loading "params" in the load_trainstate_checkpoint function in checkpoint.py, but I couldn't figure out where. Does anyone know what's wrong?
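While debugging, the structural mismatch can be made concrete with a plain-dict sketch: flax's apply() wants {"params": &lt;tree&gt;}, but the checkpoint load path apparently hands back {"params": {"params": &lt;tree&gt;}}. The helper below is hypothetical (not part of EasyLM or flax) and only illustrates what "an extra params layer" means:

```python
def unwrap_params(variables):
    """Drop one redundant outer {"params": ...} layer if present.

    flax's Module.apply expects {"params": <param tree>}; a doubly nested
    {"params": {"params": ...}} dict triggers
    ApplyScopeInvalidVariablesStructureError.
    """
    inner = variables.get("params")
    if isinstance(inner, dict) and set(inner) == {"params"}:
        return inner  # already the {"params": ...} dict apply() expects
    return variables

# Shape of the dict the traceback complains about (dummy leaf values):
nested = {"params": {"params": {"wte": {"embedding": [0.1, 0.2]}}}}
print(unwrap_params(nested))
# → {'params': {'wte': {'embedding': [0.1, 0.2]}}}
```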

sjw8793 avatar Nov 07 '23 09:11 sjw8793

There was a misunderstanding on my part: I should have used trainstate_params instead of params in my case. The serving script should therefore look like this:

$ python -m EasyLM.models.llama.llama_serve \
    --load_llama_config='1b' \
    --load_checkpoint="trainstate_params::/path/to/streaming_train_state" \
    --tokenizer.vocab_file='/path/to/tokenizer.model' \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --input_length=1024 \
    --seq_length=2048 \
    --lm_server.batch_size=4 \
    --lm_server.port=8888 \
    --lm_server.pre_compile='all'
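The only difference from the failing invocation is the prefix in --load_checkpoint. Assuming the checkpoint spec follows a type::path convention (a guess about EasyLM's format, not verified against checkpoint.py), the two modes can be illustrated like this:

```python
def parse_checkpoint_spec(spec):
    """Split a 'type::path' checkpoint spec into its two parts."""
    kind, sep, path = spec.partition("::")
    if not sep:
        raise ValueError(f"expected 'type::path', got {spec!r}")
    return kind, path

# 'params' presumably expects a bare parameter tree on disk, while
# 'trainstate_params' extracts the parameters out of a saved train state
# (which also carries optimizer state) -- hence the nesting mismatch when
# a train-state checkpoint is loaded with the 'params' prefix.
print(parse_checkpoint_spec("trainstate_params::/path/to/streaming_train_state"))
# → ('trainstate_params', '/path/to/streaming_train_state')
```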

sjw8793 avatar Nov 08 '23 06:11 sjw8793

Sorry for reopening; I thought it would be better to keep this issue open until the deprecated dependencies are resolved.

sjw8793 avatar Nov 08 '23 08:11 sjw8793