djl
djl copied to clipboard
pytorch multi classification problems Inconsistent results
i`m use tarin dataset to test; two code use one img, but output two different result in python code,print correct result. acc=1:
def predict_img(img_path):
model = torch.jit.load('./model.pt')
img = Image.open(img_path)
res_transforms = v2.Compose([transforms.Resize(232, interpolation=InterpolationMode.BILINEAR),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])
img = res_transforms(img).unsqueeze(0)
outputs = model(img)
_, prediction = torch.max(outputs, 1)
print(prediction)
predict_img("./resources/aoi/train/脏污/12_1794_27.851_7.364_O_320_20230920_155110.jpg")
but in java code,acc = 0.8
List<String> herbNames = Arrays.asList("S型", "OK", "凸点", "残胶", "洗剂残留", "絮状物", "脏污", "透明块状脏污", "长瘤");
Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
//下面的transform根据自己的改
.addTransform(new Resize(232))
.addTransform(new CenterCrop(224, 224))
.addTransform(new ToTensor())
.addTransform(new Normalize(
new float[]{0.485f, 0.456f, 0.406f},
new float[]{0.229f, 0.224f, 0.225f}))
//如果你的模型最后一层没有经过softmax就启用它
.optApplySoftmax(true)
//最终显示概率最高的5个
.optTopK(3)
.optSynset(herbNames)
.build();
Model model = Model.newInstance("resNext", Device.cpu());
model.load(new ClassPathResource("model.pt").getInputStream());
Predictor<Image, Classifications> predictor = model.newPredictor(translator);
Image img = ImageFactory.getInstance().fromInputStream(new ClassPathResource("12_1794_27.851_7.364_O_320_20230920_155110.jpg").getInputStream());
img.getWrappedImage();
Classifications classifications = predictor.predict(img);
System.out.println(classifications);
pytorch version = 2.1.1 jdl version = 0.26.0-SNAPSHOT