jetstream-pytorch
Empty response returned for prompt responses when using run_server_with_ray.py and batch_size > 1
When sending multiple prompts to the server, only the first prompt returns any results; every request after the first comes back with an empty response.
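For context, a client along the following lines is enough to hit this (a minimal sketch; the stub name, request/response fields, and port are assumptions based on the JetStream gRPC proto and may differ slightly in your checkout):

import grpc

# Assumed module path and message/stub names from the JetStream proto package.
from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc


def send(stub, prompt: str, max_tokens: int = 128) -> str:
    # Decode is a server-streaming RPC; concatenate the streamed chunks.
    request = jetstream_pb2.DecodeRequest(
        additional_text=prompt,  # assumed field name for the prompt text
        priority=0,
        max_tokens=max_tokens,
    )
    pieces = []
    for response in stub.Decode(request):
        pieces.extend(response.response)  # assumed repeated-string field
    return "".join(pieces)


def main():
    channel = grpc.insecure_channel("localhost:9000")  # assumed server port
    stub = jetstream_pb2_grpc.OrchestratorStub(channel)
    for i, prompt in enumerate(["What is JAX?", "What is a TPU?", "What is Ray?"]):
        # With run_server_with_ray and batch_size > 1, only prompt 0 returns text.
        print(f"prompt {i}: {send(stub, prompt)!r}")


if __name__ == "__main__":
    main()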
I've tried 3 different ways to bring up the server (all using interleave singlehost on a TPU v4):
python run_interactive.py --size=7b --model_name=llama-2 --batch_size=32 --max_cache_length=2048 --tokenizer_path=/home/ray/jetstream-pytorch/tokenizer.model --checkpoint_path=/home/ray/jetstream-pytorch/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"
No issues.
python run_server.py --model_name=llama-2 --size=7b --batch_size=32 --max_cache_length=2048 --tokenizer_path=/home/ray/jetstream-pytorch/tokenizer.model --checkpoint_path=/home/ray/jetstream-pytorch/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"
No issues.
python run_server_with_ray.py --tpu_chips=16 --model_name=llama-2 --size=7b --batch_size=32 --max_cache_length=2048 --tokenizer_path=/home/ray/jetstream-pytorch/tokenizer.model --checkpoint_path=/home/ray/jetstream-pytorch/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"
This one exhibits the problem described above. Debugging the code further, it looks like a stop token was returned by the model:
I0627 11:25:18.014765 137073235306240 orchestrator.py:741] >>>>data: {data}
2024-06-27 11:25:18,046 - root - INFO - Generate engine 0 step 202 - slots free : 31 / 32, took 40520.31ms
I0627 11:25:18.046449 137073243698944 orchestrator.py:678] Generate engine 0 step 202 - slots free : 31 / 32, took 40520.31ms
2024-06-27 11:25:18,046 - root - INFO - Generate thread making a decision with: prefill_backlog=0 generate_free_slots=31
I0627 11:25:18.046644 137073243698944 orchestrator.py:588] Generate thread making a decision with: prefill_backlog=0 generate_free_slots=31
2024-06-27 11:25:18,047 - root - INFO - Complete [False], slot_tokens: [[13]], slot_lengths: [2]
I0627 11:25:18.047123 137073235306240 token_utils.py:194] Complete [False], slot_tokens: [[13]], slot_lengths: [2]
2024-06-27 11:25:18,047 - root - INFO - Sample idx: 0 Speculation idx: 0 Token: 13
I0627 11:25:18.047416 137073235306240 token_utils.py:209] Sample idx: 0 Speculation idx: 0 Token: 13
2024-06-27 11:25:18,047 - root - INFO - Return samples [ReturnSample(text=['<0x0A>'], token_ids=[13])]
I0627 11:25:18.047532 137073235306240 token_utils.py:230] Return samples [ReturnSample(text=['<0x0A>'], token_ids=[13])]
2024-06-27 11:25:18,047 - root - INFO - >>>>results: [ReturnSample(text=['<0x0A>'], token_ids=[13])] complete: [False]
I0627 11:25:18.047641 137073235306240 orchestrator.py:725] >>>>results: [ReturnSample(text=['<0x0A>'], token_ids=[13])] complete: [False]
2024-06-27 11:25:18,047 - root - INFO - Detokenizing generate step 201 took 1.05ms
I0627 11:25:18.047804 137073235306240 orchestrator.py:734] Detokenizing generate step 201 took 1.05ms
2024-06-27 11:25:18,067 - root - INFO - Generate engine 0 step 203 - slots free : 31 / 32, took 20.82ms
I0627 11:25:18.067497 137073243698944 orchestrator.py:678] Generate engine 0 step 203 - slots free : 31 / 32, took 20.82ms
2024-06-27 11:25:18,068 - root - INFO - Complete [False], slot_tokens: [[0]], slot_lengths: [3]
I0627 11:25:18.068114 137073235306240 token_utils.py:194] Complete [False], slot_tokens: [[0]], slot_lengths: [3]
2024-06-27 11:25:18,068 - root - INFO - Sample idx: 0 Speculation idx: 0 Token: 0
I0627 11:25:18.068206 137073235306240 token_utils.py:209] Sample idx: 0 Speculation idx: 0 Token: 0
2024-06-27 11:25:18,068 - root - INFO - >>>complete: tok_id: 0 stop_tokens: {0, 2} valid: 1
I0627 11:25:18.068270 137073235306240 token_utils.py:216] >>>complete: tok_id: 0 stop_tokens: {0, 2} valid: 1
2024-06-27 11:25:18,068 - root - INFO - Return samples [ReturnSample(text=[], token_ids=[])]
I0627 11:25:18.068324 137073235306240 token_utils.py:230] Return samples [ReturnSample(text=[], token_ids=[])]
2024-06-27 11:25:18,068 - root - INFO - >>>>results: [ReturnSample(text=[], token_ids=[])] complete: [ True]
I0627 11:25:18.068485 137073235306240 orchestrator.py:725] >>>>results: [ReturnSample(text=[], token_ids=[])] complete: [ True]
2024-06-27 11:25:18,068 - root - INFO - Detokenizing generate step 202 took 0.72ms
I0627 11:25:18.068594 137073235306240 orchestrator.py:734] Detokenizing generate step 202 took 0.72ms
2024-06-27 11:25:18,088 - root - INFO - Generate engine 0 step 204 - slots free : 31 / 32, took 21.10ms
I0627 11:25:18.088757 137073243698944 orchestrator.py:678] Generate engine 0 step 204 - slots free : 31 / 32, took 21.10ms
2024-06-27 11:25:18,089 - root - INFO - Detokenizing generate step 203 took 0.04ms
I0627 11:25:18.089094 137073235306240 orchestrator.py:734] Detokenizing generate step 203 took 0.04ms
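The empty response is consistent with the stop-token handling visible in the logs: once a sampled token is in the stop set (here token 0, alongside EOS token 2), the slot is marked complete and nothing is emitted. Roughly, as in the simplified sketch below; ReturnSample and the stop set {0, 2} come from the logs above, while process_slot_tokens and detokenize are illustrative stand-ins, not the actual token_utils code:

from typing import List, NamedTuple


class ReturnSample(NamedTuple):
    text: List[str]
    token_ids: List[int]


def detokenize(tok_id: int) -> str:
    # Placeholder: map a token id to its piece, e.g. 13 -> '<0x0A>' for llama-2.
    return {13: "<0x0A>"}.get(tok_id, f"<{tok_id}>")


def process_slot_tokens(slot_tokens, stop_tokens=frozenset({0, 2})):
    """Return the emitted sample and whether the slot is now complete."""
    emitted_text, emitted_ids = [], []
    complete = False
    for tok_id in slot_tokens:
        if tok_id in stop_tokens:
            # Token 0 (pad) and token 2 (EOS) both terminate the slot;
            # the stop token itself is never emitted.
            complete = True
            break
        emitted_text.append(detokenize(tok_id))
        emitted_ids.append(tok_id)
    return ReturnSample(text=emitted_text, token_ids=emitted_ids), complete


# Step 202 emits token 13 ('<0x0A>') and stays incomplete; step 203 sees token 0,
# which is in the stop set, so the sample is empty and complete flips to True.
print(process_slot_tokens([13]))  # (ReturnSample(text=['<0x0A>'], token_ids=[13]), False)
print(process_slot_tokens([0]))   # (ReturnSample(text=[], token_ids=[]), True)

So under the ray launcher the slot starts producing token 0 almost immediately, which is treated as a stop token and ends the request with no visible output.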
This only repros with run_server_with_ray, and only if the batch_size is set to greater than 1.