DeepSpeed-MII

Client inference blocks when called inside multiprocessing.Process

Status: Open — zhaotyer opened this issue 11 months ago • 3 comments

I tried to integrate MII into Triton Inference Server but ran into some problems. Below is part of my code:

class TritonPythonModel:
    """Triton Python-backend model that answers requests through DeepSpeed-MII.

    ``initialize`` starts an MII deployment named "atom"; ``execute`` hands each
    request to a background thread that streams generated text back through the
    request's response sender.
    """

    def initialize(self, args):
        """Start the MII deployment and load the matching tokenizer.

        The tensor-parallel degree is the number of comma-separated entries in
        ``env_manager.cuda_visible_devices`` (an empty string counts as 1); when
        that value is None, ``torch.cuda.device_count()`` is used instead.
        """
        import mii
        from transformers import AutoTokenizer
        cuda_env = env_manager.cuda_visible_devices
        if cuda_env is None:
            from torch.cuda import device_count
            tensor_parallel_size = device_count()
        else:
            tensor_parallel_size = len(cuda_env.split(",")) if cuda_env else 1
        # NOTE(review): self.base_model_path is not assigned in this snippet;
        # presumably it is set elsewhere before initialize() runs.
        self._model = mii.serve(self.base_model_path, deployment_name="atom", tensor_parallel=tensor_parallel_size)
        self._tokenizer = AutoTokenizer.from_pretrained(self.base_model_path, trust_remote_code=True)

    def execute(self, requests):
        """Start a worker for every request and return without a response list.

        Results are sent asynchronously through each request's response
        sender, so this method always returns None.
        """
        for request in requests:
            self.process_request(request)
        return None

    def process_request(self, request):
        """Serve ``request`` on a background daemon thread.

        A thread is used rather than ``multiprocessing.Process``: in a forked
        child the MII client hangs inside the gRPC streaming call, because gRPC
        does not support reusing state inherited across fork() and the parent
        has already created gRPC objects via ``mii.serve``. Presumably the Triton
        response sender is also only usable from this process — confirm.
        """
        import threading

        worker = threading.Thread(
            target=self.mii_response_thread,
            args=(request.get_response_sender(), request),
            daemon=True,
        )
        worker.start()

    def mii_response_thread(self, response_sender, request):
        """Generate a completion for ``request`` and stream it back.

        Each chunk MII passes to the streaming callback is forwarded with
        ``self.send``. Any exception is logged and reported with
        ``self.send_error``. ``self.send_final`` is always called at the end.
        """
        try:
            import mii
            # The MII client is driven by asyncio (its stub is consumed with
            # ``async for``); a worker thread has no current event loop, so
            # install a fresh one for this thread.
            event_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(event_loop)
            req_desc = pb_utils.get_input_tensor_by_name(request, "JSON")
            req_json = json.loads(req_desc.as_numpy()[0])
            query, prompt, history, stream, gen_config, response_config, tools, mode = self.process_input_params(request, req_json)
            client = mii.client('atom')
            output_tokens = []

            def callback(response):
                # Forward each streamed chunk immediately and keep it for the summary log.
                logger.debug(f"Received: {response[0].generated_text}")
                self.send(response_sender, response[0].generated_text)
                output_tokens.append(response[0].generated_text)

            logger.debug("call mii generate")
            client.generate(prompt, max_new_tokens=4096, streaming_fn=callback)
            logger.info(f"output text is:{''.join(output_tokens)}")
        except Exception as e:
            logger.exception(f"Capture error:{e}")
            self.send_error(response_sender, f"Error occur:{e}")
        finally:
            self.send_final(response_sender)

The problem: when I use

 thread = Process(target=self.mii_response_thread,
                                    args=(request.get_response_sender(), request))

MII blocks at

async for response in getattr(self.stub,
                                      task_methods.method_stream_out)(proto_request):
            yield task_methods.unpack_response_from_proto(response)

When I use

 thread = Thread(target=self.mii_response_thread,
                                    args=(request.get_response_sender(), request))

inference works normally, but gRPC keeps reporting errors (they do not affect inference, but the service is not stable); see https://github.com/grpc/grpc/issues/25364

zhaotyer avatar Mar 29 '24 10:03 zhaotyer

I ran into a similar case. Here is my code:

def worker(rank, this_model):
    """Send one generation request to the "qwen" deployment and print the result.

    If ``this_model`` is None a new client is opened with ``mii.client('qwen')``;
    otherwise ``this_model`` itself is used as the client. Exceptions are
    printed instead of raised, and "final" is always printed last.
    """
    try:
        client = mii.client('qwen') if this_model is None else this_model
        response = client.generate(
            ["xxx"],
            max_new_tokens=1024,
            stop="<|im_end|>",
            do_sample=False,
            return_full_text=True,
        )
        print("in worker rank:", rank, " response:", response)
    except Exception as err:
        print(f"Capture error:{err}")
    finally:
        print("final")

model = mii.serve(model_dir, deployment_name="qwen", tensor_parallel=xx, replica_num=replica_num)

# Rank 0 reuses the client handle returned by mii.serve(); every other rank
# passes None so that worker() opens its own client via mii.client().
threads = [
    threading.Thread(target=worker, args=(rank, model if rank == 0 else None))
    for rank in range(replica_num)
]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

When using threading.Thread, it works well. However, it blocks in client.generate when using multiprocessing.Process.

nxznm avatar Apr 26 '24 08:04 nxznm

I ran into a similar case. Here is my code:

def worker(rank, this_model):
    """Send one generation request to the "qwen" deployment and print the result.

    If ``this_model`` is None a new client is opened with ``mii.client('qwen')``;
    otherwise ``this_model`` itself is used as the client. Exceptions are
    printed instead of raised, and "final" is always printed last.
    """
    try:
        client = mii.client('qwen') if this_model is None else this_model
        response = client.generate(
            ["xxx"],
            max_new_tokens=1024,
            stop="<|im_end|>",
            do_sample=False,
            return_full_text=True,
        )
        print("in worker rank:", rank, " response:", response)
    except Exception as err:
        print(f"Capture error:{err}")
    finally:
        print("final")

model = mii.serve(model_dir, deployment_name="qwen", tensor_parallel=xx, replica_num=replica_num)

# Rank 0 reuses the client handle returned by mii.serve(); every other rank
# passes None so that worker() opens its own client via mii.client().
threads = [
    threading.Thread(target=worker, args=(rank, model if rank == 0 else None))
    for rank in range(replica_num)
]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

When using threading.Thread, it works well. However, it blocks in client.generate when using multiprocessing.Process.

Because of the GIL, threading.Thread does not run Python code in parallel, so this code cannot make full use of concurrency. That means I still need multiprocessing.Process to start a new client; however, that does not work, as mentioned above.

nxznm avatar Apr 29 '24 02:04 nxznm

I ran into a similar case. Here is my code:

def worker(rank, this_model):
    """Send one generation request to the "qwen" deployment and print the result.

    If ``this_model`` is None a new client is opened with ``mii.client('qwen')``;
    otherwise ``this_model`` itself is used as the client. Exceptions are
    printed instead of raised, and "final" is always printed last.
    """
    try:
        client = mii.client('qwen') if this_model is None else this_model
        response = client.generate(
            ["xxx"],
            max_new_tokens=1024,
            stop="<|im_end|>",
            do_sample=False,
            return_full_text=True,
        )
        print("in worker rank:", rank, " response:", response)
    except Exception as err:
        print(f"Capture error:{err}")
    finally:
        print("final")

model = mii.serve(model_dir, deployment_name="qwen", tensor_parallel=xx, replica_num=replica_num)

# Rank 0 reuses the client handle returned by mii.serve(); every other rank
# passes None so that worker() opens its own client via mii.client().
threads = [
    threading.Thread(target=worker, args=(rank, model if rank == 0 else None))
    for rank in range(replica_num)
]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

When using threading.Thread, it works well. However, it blocks in client.generate when using multiprocessing.Process.

Because of the GIL, threading.Thread does not run Python code in parallel, so this code cannot make full use of concurrency. That means I still need multiprocessing.Process to start a new client; however, that does not work, as mentioned above.

I found the official example. Maybe we should start the server and clients the way it does.

nxznm avatar Apr 30 '24 09:04 nxznm