Single worker works, but multiple workers (`workers=2`) hang with a Keras model
KServe's multi-worker support ends in a hang/deadlock when we try to predict with a Keras-based custom model.
With `workers=1` it works fine and we get predictions, but increasing the worker count hangs the system.
https://github.com/kserve/kserve/issues/1803
@Agarwal-Saurabh KServe simply forks the process for multiple workers; you might need to check your custom code for potential locking, for example sharing the TF session among workers.
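To make the fork semantics concrete, here is a minimal, self-contained sketch (plain `multiprocessing`, not KServe code; all names are illustrative) showing that an object built before the fork is merely copied into each worker on Linux. A real `tf.Session`'s internal locks and device handles get duplicated the same way, which is what sets up the deadlock:

```python
import multiprocessing as mp
import os

class FakeSession:
    """Stands in for tf.Session: records the pid of the process that built it."""
    def __init__(self):
        self.creator_pid = os.getpid()

def _report(session, queue):
    # Each worker reports (its own pid, the pid that created the session).
    queue.put((os.getpid(), session.creator_pid))

def run_demo(n_workers=2):
    session = FakeSession()               # built once, before forking (like load())
    queue = mp.Queue()
    procs = [mp.Process(target=_report, args=(session, queue))
             for _ in range(n_workers)]
    for p in procs:
        p.start()
    results = [queue.get() for _ in procs]
    for p in procs:
        p.join()
    return results

if __name__ == "__main__":
    # Every worker sees a session stamped with the parent's pid:
    # the object predates the fork and is only a copy in each child.
    print(run_demo())
```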
@yuzisun I work with @Agarwal-Saurabh, and we have tried many variations of passing the same `tf.Session()` among different child processes. I am attaching rough code to help you understand the flow.
```python
# Imports inferred from context (TF 1.x-style graph/session API):
import tensorflow as tf
from keras import backend as K
from keras.layers import Input
from keras.models import Model

def load_pose_estimation_model(num_classes, num_connections, prev_trained_model,
                               darknet=None, image_input=None):
    session = tf.Session(graph=tf.Graph())
    with session.graph.as_default():
        K.set_session(session)
        by_name = True
        if image_input is None and darknet is None:
            image_input = Input(shape=(None, None, 6))
            darknet = Model(image_input, darknet_body(image_input))
            by_name = False
        key_value_body = yolo4_key_value_matching(image_input, darknet,
                                                  num_classes, num_connections)
        model = Model(darknet.input, [*key_value_body.output])
        model.load_weights(prev_trained_model, by_name=by_name, skip_mismatch=True)
        print("loaded prev trained model's weights from:", prev_trained_model)
        model._make_predict_function()
        session.graph.finalize()
    return model, session
```
```python
import base64
import os
import uuid
from typing import Dict

import kserve

class ModelService(kserve.Model):
    """Class to serve model predictions."""

    def __init__(self, model_args: dict, bucket_args: dict, file_args: dict, aws_creds: dict):
        super().__init__(bucket_args['model_uid'])
        self.model_args = model_args
        self.bucket_args = bucket_args
        self.file_args = file_args
        self.aws_creds = aws_creds

    def load(self):
        """Load training artifacts."""
        # load the models
        model, paf_session = load_pose_estimation_model(**self.model_args)
        self.model_args['pose_model'] = model
        self.model_args['paf_session'] = paf_session
        logger.info("Model initialised successfully")

    def predict(self, request: dict) -> Dict:
        """Conduct inference."""
        assert request is not None
        try:
            doc_file_str = request['instance']
            img_path = uuid.uuid4().hex + '.jpg'
            with open(img_path, "wb") as f:
                f.write(base64.b64decode(doc_file_str))
        except Exception as e:
            raise TypeError(f"Failed to download data from Data API: {e}")
        result = predict_kvp(img_path=img_path, **self.model_args)
        os.remove(img_path)
        return {"predictions": result}
```
```python
def serve():
    model_service = ModelService(
        model_args=model_args,
        aws_creds=aws_creds,
        bucket_args=bucket_args,
        file_args=file_args,
    )
    model_service.load()
    kserve.ModelServer(workers=2, http_port=8080).start([model_service])

if __name__ == "__main__":
    logger.info(model_args)
    serve()
```
Now the entire process hangs inside the predict method. The following snippet shows the inference code.
```python
import cv2
import numpy as np

def detect_image(model, img, ref_image, paf_session, input_size=(608, 608)):
    original_width, original_height = img.shape[1], img.shape[0]
    img = cv2.resize(img, input_size)
    img = img / 255.0
    ref_image = cv2.resize(ref_image, input_size)
    ref_image = ref_image / 255.0
    img = np.expand_dims(img, 0)
    ref_image = np.expand_dims(ref_image, 0)
    print('reached nested 1')
    with paf_session.graph.as_default():
        print('reached nested 2')
        K.set_session(paf_session)
        heatmaps, connections = model.predict([np.concatenate([img, ref_image], axis=-1)])
    return heatmaps
```
In the code above, the process hangs after printing 'reached nested 2'; we are unable to make a prediction using the Keras model with the TensorFlow backend. The code also includes the session sharing that enables multi-threaded use of Keras. It works perfectly behind a Flask API.
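The Flask case works because a threaded server keeps everything in one process: every thread genuinely shares the single loaded object, whereas forked workers only receive copies. A minimal sketch (illustrative names only, standing in for the shared Keras model) of that single-process sharing:

```python
import threading

class SharedModel:
    """Stands in for the loaded Keras model: one object shared by all
    request threads, the way a threaded Flask server shares it."""
    def __init__(self):
        self.lock = threading.Lock()
        self.calls = 0              # shared state, visible to every thread

    def predict(self, x):
        with self.lock:             # serialize access to the shared state
            self.calls += 1
        return x * 2                # stand-in for the real inference

def run_threaded(n_threads=4):
    model = SharedModel()           # loaded once, in the single process
    results = []
    res_lock = threading.Lock()

    def worker(i):
        y = model.predict(i)
        with res_lock:
            results.append(y)

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # All threads hit the *same* model instance, so calls == n_threads.
    return model.calls, sorted(results)
```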
@Agarwal-Saurabh Thanks for the detailed code! I think the issue is that `model.load` is called before the model server forks the worker processes, so multiple processes effectively share the same session object, hence the deadlock.
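One workaround consistent with this diagnosis is to defer the heavy load until after the fork, so each worker builds its own session. A minimal sketch (hypothetical class and method names, not KServe's API) of lazy loading on first predict:

```python
import os

class LazyModelService:
    """Sketch only: defer the heavy model build until the first predict(),
    which runs inside the forked worker rather than the parent."""
    def __init__(self):
        self._model = None
        self._loader_pid = None

    def _build_model(self):
        # Placeholder for the real load_pose_estimation_model() call;
        # records which process actually built the model.
        return {"built_in": os.getpid()}

    def _ensure_loaded(self):
        # Rebuild if never loaded, or if this process is a fork of the
        # process that did the loading (pid no longer matches).
        if self._model is None or self._loader_pid != os.getpid():
            self._model = self._build_model()
            self._loader_pid = os.getpid()

    def predict(self, request):
        self._ensure_loaded()
        return {"predictions": request, "built_in": self._model["built_in"]}
```

The pid check means that even if the parent accidentally triggers a load before forking, each worker still rebuilds its own private session on first use.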
@yuzisun Thanks for the prompt response. We actually want to share the same session across multiple workers so the model is loaded only once. It is more of a multi-threading approach, where a single loaded model is used for inference concurrently.
We were able to do this in Flask with multi-threading. Please let us know whether this is possible in KServe. Thanks for your help.
@harshyadav17 Tornado does not use multiple threads; it makes inference asynchronous instead, which can easily scale to more than 10k connections, while a multi-threaded server simply cannot. I am not sure how multiple threads would help you run CPU-intensive inference, especially with the Python GIL.
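To illustrate the asynchronous pattern described above, here is a minimal sketch (illustrative names, not KServe's API) of offloading a blocking predict call to an executor so the event loop stays free to accept other requests. A thread pool is shown because TensorFlow kernels release the GIL during execution; pure-Python CPU-bound work would need a `ProcessPoolExecutor` instead:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

def heavy_inference(x):
    # Stand-in for model.predict(): a blocking, CPU-bound call.
    return x * 2

async def predict_async(x, pool):
    # The event loop keeps serving other requests while the blocking
    # call runs in the pool.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(pool, heavy_inference, x)

async def main():
    with ThreadPoolExecutor(max_workers=2) as pool:
        # gather() preserves argument order in its result list.
        return await asyncio.gather(*(predict_async(i, pool) for i in range(4)))

if __name__ == "__main__":
    print(asyncio.run(main()))   # [0, 2, 4, 6]
```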