Add section Prediction in high-level-design.md
following the discussion here https://github.com/wangkuiyi/elasticdl/pull/815#discussion_r298789944
For prediction, there can be two kinds of prediction calls:

- Takes a model input or a batch of inputs, and returns the model outputs.

  ```python
  elasticdl.predict(trained_model, input_data)
  ```

  This is a blocking call, and users can use it in their inference service.

- Takes data from RecordIO files, and stores the output to a user-specified location.

  ```python
  elasticdl.predict(trained_model, recordio_data, output_location)
  ```

  This call submits the prediction job to a Kubernetes cluster, and distributed prediction is used. It may be used by SQLFlow. Per the previous document, we are going to implement this prediction call.
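To make the two call shapes concrete, here is a minimal Python sketch. The names `predict_inputs` and `predict_recordio` (and splitting `elasticdl.predict` into two functions) are illustrative assumptions, not the actual elasticdl API:

```python
# Illustrative sketch only; names and signatures are assumptions.

def predict_inputs(trained_model, input_data):
    """Blocking call: run the model on one input or a batch, return outputs."""
    return trained_model(input_data)

def predict_recordio(trained_model, recordio_data, output_location):
    """Submit a distributed prediction job over RecordIO data.

    In the real system this would create a job on a Kubernetes cluster
    and write results to output_location; here we only show the call shape.
    """
    raise NotImplementedError("submits a distributed prediction job on Kubernetes")
```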
Let's look back at how the data is processed in the training process:
```
rd_data = RecordIO data read from file
model_input, label = input_fn(rd_data)
model_output = model(model_input, training=True)
batch_loss = loss(model_output, label)
# then compute gradients and update the model
```
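The training flow above can be sketched as a framework-free toy: `input_fn`, `model`, and `loss` mirror the pseudocode, while the one-parameter linear model and MSE loss below are illustrative stand-ins for the real network.

```python
# Toy, framework-free sketch of the training data flow.

def input_fn(rd_data):
    # Decode a record into (model_input, label); here each record is
    # assumed to be a (feature, label) pair.
    feature, label = rd_data
    return feature, label

def model(x, w, training=True):
    # One-parameter linear model standing in for the real network.
    return w * x

def loss(output, label):
    # Mean squared error on a single example.
    return (output - label) ** 2

def train_step(rd_data, w, lr=0.1):
    x, y = input_fn(rd_data)
    y_hat = model(x, w, training=True)
    # d(loss)/dw = 2 * (y_hat - y) * x for the MSE loss above.
    grad = 2 * (y_hat - y) * x
    return w - lr * grad
```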
For prediction on RecordIO data, there are some differences from the training process:

- RecordIO data does not have labels, so the decoding in `input_fn` is different.
- If there are data augmentation ops in `input_fn`, those ops should not be used in prediction. So a new `input_fn` is needed for prediction, or `input_fn` should take a `training` argument.
```
rd_data = RecordIO data read from file
model_input = input_fn(rd_data, training=False)
model_output = model(model_input, training=False)
```
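One way to realize the `training` argument is a single `input_fn` that branches on it; the record layout and the `add_noise` augmentation op below are assumptions for illustration.

```python
import random

def add_noise(feature, scale=0.01):
    # Stand-in data-augmentation op; applied only during training.
    return feature + random.uniform(-scale, scale)

def input_fn(rd_data, training=True):
    if training:
        # Training records are assumed to carry a label.
        feature, label = rd_data
        return add_noise(feature), label
    # Prediction records have no label and no augmentation.
    return rd_data
```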
Also, the user may want to process the model output before saving it. For example, a ResNet50-ImageNet model's output is the output of a softmax layer, which is a 1000-d tensor. Instead of saving this 1000-d tensor, the user might want to save just the top-1 or top-5 predictions:
```
top_k_predictions = top_k_index(model_output)
```
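As a concrete sketch, `top_k_index` could be a pure-Python helper that returns the indices of the `k` largest entries of the output vector; in practice something like `tf.math.top_k` would likely be used instead.

```python
import heapq

def top_k_index(model_output, k=5):
    # Take the k largest (value, index) pairs, then keep only the indices,
    # ordered from highest to lowest score.
    return [i for _, i in heapq.nlargest(
        k, ((v, i) for i, v in enumerate(model_output)))]
```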
Should we support this post-processing, or does the user have to save the original model output?
@skydoorkai Yes I think users may want to post-process before saving the output so they won't need to iterate through the original model output again.
I added the setup for prediction task support in https://github.com/wangkuiyi/elasticdl/pull/823. However, there is still remaining work that needs further discussion:
- How do we separate `input_fn` for different tasks? Would something like the above-mentioned `input_fn(rd_data, training=False)` suffice? If it's an evaluation task, the `labels` part in `input_fn` can just be `None`.
- Do we need to send prediction results from each worker to the master before writing them to the destination? It might be expensive to send prediction results to the master, but users might need to perform some sort of aggregation/analysis before writing the prediction results out.
- How do users configure where/how to write prediction results? We can probably just expose a function for them to do this.
- What does the prediction output look like? See https://github.com/wangkuiyi/elasticdl/pull/815#discussion_r298789944 for related discussion.
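For the "expose a function" idea above, a minimal sketch might let each worker hand every batch of results to a user-supplied writer; the hook name and signature here are hypothetical, not the actual elasticdl API.

```python
# Hypothetical user-configurable output writer; collected in memory here,
# but the user could instead write to a database, object store, etc.
results = []

def my_output_writer(batch_outputs):
    results.extend(batch_outputs)

def run_prediction(batches, model_fn, output_writer):
    # Worker-side loop: run the model on each batch, then let the
    # user-supplied writer decide where/how results are stored.
    for batch in batches:
        output_writer([model_fn(x) for x in batch])
```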