Any methods to support params update?
Checklist
- [x] 1. I have searched related issues but cannot get the expected help.
- [x] 2. The bug has not been fixed in the latest version.
- [x] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
Describe the bug
Does LMDeploy have any methods to update the model params or the model instances, like vLLM?
Reproduction
none
Environment
none
Error traceback
none
LMDeploy already supports this.
Weights can be updated through the api_server's /update_weights endpoint.
Is there any reference document?
In addition to updating via the API, does the pipeline also support updating weights? Is there any documentation available?
Sorry, we haven't prepared the user guide yet. Here’s how api_server updates the model weights:
@router.post('/update_weights', dependencies=[Depends(check_api_key)])
def update_params(request: UpdateParamsRequest, raw_request: Request = None):
    """Update weights for the model."""
    VariableInterface.async_engine.engine.update_params(request)
    return JSONResponse(content=None)
Note that pipeline is an alias for async_engine, so in theory the pipeline also supports this feature. However, users are responsible for constructing the request object, which is not a trivial task.
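For reference, a minimal sketch of constructing that request and feeding it to the pipeline's engine directly, assuming UpdateParamsRequest takes the same serialized_named_tensors and finished fields as the HTTP payload shown later in this thread; the import path for UpdateParamsRequest below is a guess and should be checked against api_server.py's imports:

import torch
from typing import Dict

from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.utils import serialize_state_dict
# Assumption: the import path may differ; see where api_server.py imports UpdateParamsRequest from.
from lmdeploy.serve.openai.protocol import UpdateParamsRequest

pipe = pipeline('path/to/model', backend_config=PytorchEngineConfig(tp=1))

new_weights: Dict[str, torch.Tensor] = ...  # the updated state dict (or one segment of it)
request = UpdateParamsRequest(
    serialized_named_tensors=serialize_state_dict(new_weights),
    finished=True,  # last (here: only) segment
)
# pipeline wraps async_engine, so this mirrors the api_server handler above.
pipe.engine.update_params(request)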
@RunningLeon Could you help write the user guide?
How should this request be constructed? And, if it's offline inference, is it possible to update directly from the model instead of converting it to base64 or something similar?
@Auraithm Currently we only support updating weights through the HTTP API. Here is a simple example.
import requests
import torch
from typing import Dict, List

from lmdeploy.utils import serialize_state_dict

BASE_URL = 'http://0.0.0.0:24545'
api_key = 'sk-xxx'
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {api_key}",
}

# Split the state dict into segments so that each request stays small.
segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
num_segment = len(segmented_state_dict)
for seg_idx in range(num_segment):
    serialized_data = serialize_state_dict(segmented_state_dict[seg_idx])
    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1)
    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
    assert response.status_code == 200, f"response.status_code = {response.status_code}"
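As an illustration of one way to build segmented_state_dict, here is a hedged sketch that simply chunks a state dict into groups of tensors; chunk_named_tensors is a hypothetical helper (not part of LMDeploy) and the chunk size is arbitrary:

import torch
from typing import Dict, List


def chunk_named_tensors(state_dict: Dict[str, torch.Tensor],
                        tensors_per_segment: int = 64) -> List[Dict[str, torch.Tensor]]:
    """Hypothetical helper: split a state dict into smaller dicts of at most
    tensors_per_segment entries each, so every /update_weights request stays small."""
    items = list(state_dict.items())
    return [dict(items[i:i + tensors_per_segment])
            for i in range(0, len(items), tensors_per_segment)]


# Usage (illustrative): segmented_state_dict = chunk_named_tensors(model.state_dict())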
What is the version of lmdeploy? I used version 0.10.1.
await self.middleware_stack(scope, receive, send)
File "/usr/local/lib/python3.12/dist-packages/starlette/middleware/errors.py", line 186, in __call__
raise exc
File "/usr/local/lib/python3.12/dist-packages/starlette/middleware/errors.py", line 164, in __call__
await self.app(scope, receive, _send)
File "/usr/local/lib/python3.12/dist-packages/starlette/middleware/cors.py", line 85, in __call__
await self.app(scope, receive, send)
File "/usr/local/lib/python3.12/dist-packages/starlette/middleware/exceptions.py", line 63, in __call__
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
File "/usr/local/lib/python3.12/dist-packages/starlette/_exception_handler.py", line 53, in wrapped_app
raise exc
File "/usr/local/lib/python3.12/dist-packages/starlette/_exception_handler.py", line 42, in wrapped_app
await app(scope, receive, sender)
File "/usr/local/lib/python3.12/dist-packages/fastapi/middleware/asyncexitstack.py", line 18, in __call__
await self.app(scope, receive, send)
File "/usr/local/lib/python3.12/dist-packages/starlette/routing.py", line 716, in __call__
await self.middleware_stack(scope, receive, send)
File "/usr/local/lib/python3.12/dist-packages/starlette/routing.py", line 736, in app
await route.handle(scope, receive, send)
File "/usr/local/lib/python3.12/dist-packages/starlette/routing.py", line 290, in handle
await self.app(scope, receive, send)
File "/usr/local/lib/python3.12/dist-packages/fastapi/routing.py", line 123, in app
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
File "/usr/local/lib/python3.12/dist-packages/starlette/_exception_handler.py", line 53, in wrapped_app
raise exc
File "/usr/local/lib/python3.12/dist-packages/starlette/_exception_handler.py", line 42, in wrapped_app
await app(scope, receive, sender)
File "/usr/local/lib/python3.12/dist-packages/fastapi/routing.py", line 109, in app
response = await f(request)
^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/fastapi/routing.py", line 389, in app
raw_response = await run_endpoint_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/fastapi/routing.py", line 290, in run_endpoint_function
return await run_in_threadpool(dependant.call, **values)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/starlette/concurrency.py", line 38, in run_in_threadpool
return await anyio.to_thread.run_sync(func)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/anyio/to_thread.py", line 56, in run_sync
return await get_async_backend().run_sync_in_worker_thread(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/anyio/_backends/_asyncio.py", line 2461, in run_sync_in_worker_thread
return await future
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/anyio/_backends/_asyncio.py", line 962, in run
result = context.run(func, *args)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/lmdeploy/serve/openai/api_server.py", line 990, in update_params
VariableInterface.async_engine.engine.update_params(request)
File "/usr/local/lib/python3.12/dist-packages/lmdeploy/pytorch/engine/mp_engine/base.py", line 54, in update_params
return self._collective_rpc('update_params', request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/lmdeploy/pytorch/engine/mp_engine/zmq_engine.py", line 137, in _collective_rpc
return self.rpc_client.call(func, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py", line 249, in call
raise Exception(reply['error'])
Exception: Not Implemented.
@Auraithm hi, sorry for the trouble. You have to change the source code to set distributed_executor_backend='ray' for the PyTorch backend engine before starting the server.
Changing the source code is a workaround, since this argument is not exposed in the lmdeploy serve api_server CLI command.
A PR will fix this later.
export LMDEPLOY_EXECUTOR_BACKEND=ray
lmdeploy serve api_server /inspire/hdd/global_user/liuxiaoran-240108120089/public/SDAR-8B-Chat \
    --model-name "SDAR-8B-Chat" \
    --server-port 23333 \
    --tp 8

What needs to be modified? I added this environment variable by referring to the documentation, but I have never used ray. Do I need to modify anything?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: invalid resource handle
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
2025-11-24 05:24:48,353 ERROR worker.py:430 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::RayWorkerWrapper.update_params() (pid=47624, ip=10.247.56.69, actor_id=d1a95221a84a339f49f49b5d01000000, repr=<lmdeploy.pytorch.engine.executor.ray_executor.RayWorkerWrapper object at 0x7f1f0ebc3c20>)
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 449, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/lmdeploy/pytorch/engine/executor/base_worker.py", line 124, in update_params
self.model_agent.update_params(request)
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/lmdeploy/pytorch/engine/model_agent.py", line 1028, in update_params
weights = [(k, _construct(v)) for k, v in weights]
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/lmdeploy/pytorch/engine/model_agent.py", line 1021, in _construct
return func(*args).clone()
^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/multiprocessing/reductions.py", line 181, in rebuild_cuda_tensor
storage = storage_cls._new_shared_cuda(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/storage.py", line 1457, in _new_shared_cuda
return torch.UntypedStorage._new_shared_cuda(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: invalid resource handle
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Add backend_config.distributed_executor_backend='ray' here:
https://github.com/InternLM/lmdeploy/blob/1984a5d4bd447b9206f9e0f9d2110f45eecf827d/lmdeploy/serve/openai/api_server.py#L1288
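A minimal sketch of that workaround, assuming the linked spot (around line 1288 of lmdeploy/serve/openai/api_server.py at commit 1984a5d, where backend_config.enable_mp_engine is set) is where the option belongs:

# lmdeploy/serve/openai/api_server.py, around line 1288 (commit 1984a5d)
backend_config.enable_mp_engine = True
backend_config.distributed_executor_backend = 'ray'  # added line: use the Ray executor so update_params works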
I added this, but it still failed with the same error as above.
import torch
from typing import List, Dict
from transformers import AutoModelForCausalLM
import requests
from lmdeploy.utils import serialize_state_dict

# 1. Start the API server:
# lmdeploy serve /path/to/model --backend pytorch --server-port 23333

# 2. Load the trained model
trained_model = AutoModelForCausalLM.from_pretrained(
    "xxx/SDAR-8B-Chat",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map="cuda"
)

# 3. Prepare the weights and split them into segments
state_dict = trained_model.state_dict()
# Split the state_dict into segments (adjust the segmentation strategy as needed)
segmented_state_dict: List[Dict[str, torch.Tensor]] = [state_dict]

# 4. Send the weight update requests
BASE_URL = 'http://0.0.0.0:23333'
api_key = 'sk-xxx'
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {api_key}",
}
num_segment = len(segmented_state_dict)
for seg_idx in range(num_segment):
    print(f"Serializing segment {seg_idx + 1}/{num_segment}...")
    serialized_data = serialize_state_dict(segmented_state_dict[seg_idx])
    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1)
    print(f"Sending segment {seg_idx + 1}/{num_segment}...")
    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
    if response.status_code != 200:
        print(f"Error response status: {response.status_code}")
        print(f"Error response text: {response.text}")
        assert False, f"Failed to update weights: status_code={response.status_code}"
    print(f"Segment {seg_idx + 1}/{num_segment} updated successfully")

print("All weights updated successfully!")
This is my code.
@Auraithm hi, pls. follow this.
# step 1: sleep
response = requests.post(f"{BASE_URL}/sleep", headers=headers, params=dict(tags=['weights', 'kv_cache'], level=2))
assert response.status_code == 200, response.status_code

# step 2: update params
segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
num_segment = len(segmented_state_dict)
for seg_idx in range(num_segment):
    serialized_data = serialize_state_dict(segmented_state_dict[seg_idx])
    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1)
    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
    assert response.status_code == 200, f"response.status_code = {response.status_code}"

# step 3: wake up
response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['weights', 'kv_cache']))
assert response.status_code == 200, response.status_code

# step 4: infer as normal
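For step 4, a small sanity-check sketch, assuming the server was started with --model-name "SDAR-8B-Chat" as in the commands above and exposes the usual OpenAI-compatible /v1/chat/completions route:

# After /wakeup, a normal chat completion should succeed and be served by the updated weights.
payload = {
    "model": "SDAR-8B-Chat",  # assumption: must match the --model-name used at startup
    "messages": [{"role": "user", "content": "Say hello."}],
    "max_tokens": 32,
}
response = requests.post(f"{BASE_URL}/v1/chat/completions", headers=headers, json=payload)
assert response.status_code == 200, response.text
print(response.json()["choices"][0]["message"]["content"])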
Thank you, I've succeeded! LMDeploy is currently the only inference engine that supports the SDAR model, so we developed an RL training framework based on it. We had been struggling with weight updates, and I think this solution will further improve our training efficiency.
@Auraithm Hi, pls refer to https://lmdeploy.readthedocs.io/en/latest/advance/update_weights.html
Thank you. If I want to use proxy_server to deploy on multiple nodes as in https://lmdeploy.readthedocs.io/en/latest/llm/proxy_server.html, and I want to update weights, which URL do I need to send the request to? The proxy_url or the individual api_server?
@Auraithm The individual api_server URL, not the proxy_url.
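As a hedged sketch of what that means in practice: if several api_server instances are registered with the proxy, each presumably needs the same update, so push the segments to every server's own URL (the URL list below is illustrative; maintain it yourself):

import requests
import torch
from typing import Dict, List

from lmdeploy.utils import serialize_state_dict

# Assumption: you keep this list yourself; it holds the api_server URLs registered with the proxy.
API_SERVER_URLS = ["http://node0:23333", "http://node1:23333"]
headers = {"Content-Type": "application/json", "Authorization": "Bearer sk-xxx"}

segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
num_segment = len(segmented_state_dict)

# Push the same segments to every api_server behind the proxy, one server at a time.
for base_url in API_SERVER_URLS:
    for seg_idx in range(num_segment):
        serialized_data = serialize_state_dict(segmented_state_dict[seg_idx])
        data = dict(serialized_named_tensors=serialized_data,
                    finished=seg_idx == num_segment - 1)
        response = requests.post(f"{base_url}/update_weights", headers=headers, json=data)
        assert response.status_code == 200, f"{base_url}: {response.status_code}"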