djl
Error when using the BatchPredict API
Hi, I'm using the BatchPredict API for batch prediction. When the input is `NDList size: 1, 0 : (2, 64) int32` or `NDList size: 2, 0 : (64) int32, 1 : (64) int32`, I get `RuntimeError: The size of tensor a (64) must match the size of tensor b (43) at non-singleton dimension 2`. When the input is `NDList size: 1, 0 : (64) int32`, it works fine. So I would like to ask: can `batchPredict` actually perform batch prediction? If so, what shape should the input have?
Thank you very much!
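One way to reason about the shapes above is a shape-only model of a STACK batchifier. Assuming `Batchifier.STACK` treats each element of the list passed to `batchPredict` as one sample and stacks the i-th array of every sample's NDList along a new leading axis, the plain-Java sketch below computes the resulting batched shapes. It makes no DJL calls; the class `StackShapeSketch` and its methods are hypothetical names used only for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical shape-only model of stacking a batch of NDLists (no DJL involved).
public class StackShapeSketch {

    // samples.get(s).get(i) is the shape of the i-th array in sample s's NDList.
    // Returns the shape of the stacked array for every position i.
    static List<long[]> stack(List<List<long[]>> samples) {
        int width = samples.get(0).size();
        List<long[]> batched = new ArrayList<>();
        for (int i = 0; i < width; i++) {
            long[] first = samples.get(0).get(i);
            for (List<long[]> sample : samples) {
                // Every sample must have the same number of arrays with equal shapes
                if (sample.size() != width || !Arrays.equals(sample.get(i), first)) {
                    throw new IllegalArgumentException("all samples must have matching shapes");
                }
            }
            // New leading axis whose size is the number of samples in the batch
            long[] shape = new long[first.length + 1];
            shape[0] = samples.size();
            System.arraycopy(first, 0, shape, 1, first.length);
            batched.add(shape);
        }
        return batched;
    }

    // Formats the list of shapes, e.g. [[2, 64]]
    static String describe(List<long[]> shapes) {
        List<String> parts = new ArrayList<>();
        for (long[] s : shapes) {
            parts.add(Arrays.toString(s));
        }
        return parts.toString();
    }

    public static void main(String[] args) {
        long[] vec = {64};
        // One sample holding a (2, 64) array -> [[1, 2, 64]]
        System.out.println(describe(stack(List.of(List.of(new long[] {2, 64})))));
        // One sample holding two (64) arrays -> [[1, 64], [1, 64]]
        System.out.println(describe(stack(List.of(List.of(vec, vec)))));
        // One sample holding one (64) array -> [[1, 64]]
        System.out.println(describe(stack(List.of(List.of(vec)))));
        // Two samples, each holding one (64) array -> [[2, 64]]
        System.out.println(describe(stack(List.of(List.of(vec), List.of(vec)))));
    }
}
```

If this model matches DJL's behavior, the first two inputs would reach the model as a (1, 2, 64) tensor or as two separate (1, 64) tensors rather than one (2, 64) batch, which would be consistent with the dimension mismatch in the error; passing two samples that each hold one (64) array is what yields a (2, 64) batch.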
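```java
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Paths;
import java.util.Collections;

Criteria<NDList, Float> criteria =
        Criteria.builder()
                .setTypes(NDList.class, Float.class)
                .optModelPath(Paths.get(modelPath))
                .optTranslator(new MyTranslator())
                .optEngine("PyTorch")
                .build();
ZooModel<NDList, Float> model = ModelZoo.loadModel(criteria);
Predictor<NDList, Float> predictor = model.newPredictor();

public static class MyTranslator implements Translator<NDList, Float> {
    // STACK stacks the i-th NDArray of every input NDList along a new axis 0
    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }

    // Pass the input NDList to the model unchanged
    @Override
    public NDList processInput(TranslatorContext ctx, NDList inputs) {
        return inputs;
    }

    // Read the first output array of a sample as a float
    @Override
    public Float processOutput(TranslatorContext ctx, NDList outputs) {
        return outputs.get(0).getFloat();
    }
}

// Each list element is one sample, so singletonList is a batch of size 1
predictor.batchPredict(Collections.singletonList(inputs));
```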
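On the output side, `batchPredict` returns one result per element of the input list: as far as I understand DJL, the batched model output is split back along axis 0 and `processOutput` then runs once per sample. The plain-Java sketch below mimics only that splitting, assuming a hypothetical model output of shape (N, 1) stored as a `float[][]`; `UnstackSketch` and `unstack` are hypothetical names, not DJL APIs.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical model of the output path: split a batched (N, 1) output along
// axis 0 and turn each slice into one Float, one result per sample.
public class UnstackSketch {

    static List<Float> unstack(float[][] batchedOutput) {
        List<Float> results = new ArrayList<>();
        for (float[] row : batchedOutput) {
            // Each row is one sample's output; keep its first value
            results.add(row[0]);
        }
        return results;
    }

    public static void main(String[] args) {
        // Batch of two samples -> two results
        System.out.println(unstack(new float[][] {{0.25f}, {0.75f}}));
        // Batch of one sample, as with Collections.singletonList(inputs) -> one result
        System.out.println(unstack(new float[][] {{0.5f}}));
    }
}
```

Under this model, a singleton input list always produces a single result, whatever shape the arrays inside that one NDList have.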