
Error when using the batchPredict API

Open suntingfeng opened this issue 6 months ago • 1 comments

Hi, I am using the batchPredict API for batch prediction. When the input is either

NDList size: 1, 0 : (2, 64) int32

or

NDList size: 2, 0 : (64) int32, 1 : (64) int32

I get the following error:

RuntimeError: The size of tensor a (64) must match the size of tensor b (43) at non-singleton dimension 2

When the input is NDList size: 1, 0 : (64) int32, it works fine. So I would like to ask: can batchPredict actually perform batch prediction? If so, what shape should the input have?

Thank you very much!

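import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Paths;
import java.util.Collections;

// Load the PyTorch model with a custom translator.
Criteria<NDList, Float> criteria =
        Criteria.builder()
                .setTypes(NDList.class, Float.class)
                .optModelPath(Paths.get(modelPath))
                .optTranslator(new MyTranslator())
                .optEngine("PyTorch")
                .build();
ZooModel<NDList, Float> model = ModelZoo.loadModel(criteria);
predictor = model.newPredictor();


public static class MyTranslator implements Translator<NDList, Float> {

    // Stack the samples passed to batchPredict along a new batch axis.
    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }

    // Pass the input NDList through unchanged.
    @Override
    public NDList processInput(TranslatorContext translatorContext, NDList inputs) {
        return inputs;
    }

    // Read the first output array as a single float.
    @Override
    public Float processOutput(TranslatorContext ctx, NDList outputs) {
        return outputs.get(0).getFloat();
    }
}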

predictor.batchPredict(Collections.singletonList(inputs));
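
To make the shape question concrete, here is a rough plain-Java sketch (not DJL code) of what stack-style batching does to shapes. batchPredict treats each element of the list it receives as one sample, and Batchifier.STACK joins the i-th array of every sample along a new leading axis. The class and method names below (StackShapeSketch, stackedShapes, describe) are made up for illustration, and the sketch models shapes only.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Plain-Java illustration of stack-style batching on shapes only; hypothetical helper, not part of DJL. */
public class StackShapeSketch {

    /**
     * Returns the shapes a model would receive after stacking a batch.
     * Each sample is a list of array shapes (one entry per array in that sample's NDList).
     * The i-th array of every sample is joined along a new leading batch axis.
     */
    public static List<long[]> stackedShapes(List<List<long[]>> samples) {
        if (samples.isEmpty()) {
            throw new IllegalArgumentException("empty batch");
        }
        int arraysPerSample = samples.get(0).size();
        List<long[]> result = new ArrayList<>();
        for (int i = 0; i < arraysPerSample; i++) {
            long[] first = samples.get(0).get(i);
            // Stacking only works if every sample has the same layout.
            for (List<long[]> sample : samples) {
                if (sample.size() != arraysPerSample || !Arrays.equals(sample.get(i), first)) {
                    throw new IllegalArgumentException("samples must have matching shapes to be stacked");
                }
            }
            // New shape = (number of samples, original dimensions...).
            long[] batched = new long[first.length + 1];
            batched[0] = samples.size();
            System.arraycopy(first, 0, batched, 1, first.length);
            result.add(batched);
        }
        return result;
    }

    /** Formats a list of shapes as one string, e.g. "[1, 64][1, 64]". */
    public static String describe(List<long[]> shapes) {
        StringBuilder sb = new StringBuilder();
        for (long[] shape : shapes) {
            sb.append(Arrays.toString(shape));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // singletonList(NDList with one (2, 64) array): one sample, extra leading axis.
        System.out.println(describe(stackedShapes(List.of(List.of(new long[] {2, 64})))));
        // singletonList(NDList with two (64) arrays): one sample, two separate model inputs.
        System.out.println(describe(stackedShapes(List.of(List.of(new long[] {64}, new long[] {64})))));
        // Two samples, each an NDList with one (64) array: a single (2, 64) batch.
        System.out.println(describe(stackedShapes(
                List.of(List.of(new long[] {64}), List.of(new long[] {64})))));
    }
}
```

If the model expects a single (batch, 64) input (an assumption; the post does not state the model's signature, although the working (64) case suggests it), then the first two layouts would reach the model as (1, 2, 64) or as two (1, 64) inputs, which may be where the size mismatch comes from. Under that assumption, batchPredict would be given a list of two NDLists, each holding one (64) array, instead of a singleton list wrapping both rows.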

suntingfeng, Aug 09 '24 03:08