djl

PyTorch multi-class classification: inconsistent results between Python and DJL (Java)

Open yz4322gly opened this issue 1 year ago • 0 comments

I'm using the training dataset as the test set. Both programs classify the same image but produce different results. The Python code prints the correct result (acc = 1):

def predict_img(img_path, model_path="./model.pt"):
    """Classify a single image with a TorchScript model and print the result.

    Preprocessing: resize the shorter edge to 232 (bilinear, aspect ratio
    kept), center-crop 224x224, convert to a float tensor in [0, 1], then
    normalize with the mean/std below (the usual ImageNet statistics).

    Args:
        img_path: Path of the image file to classify.
        model_path: Path of the TorchScript archive to load. Defaults to
            ``./model.pt``, the path this function previously hard-coded.

    Returns:
        The tensor of predicted class indices (one entry per batch item;
        a single entry here), i.e. the same value that is printed.
    """
    model = torch.jit.load(model_path)
    # Ensure inference behaviour (BatchNorm uses running stats, dropout off)
    # regardless of the mode the archive was saved in.
    model.eval()

    res_transforms = v2.Compose([
        transforms.Resize(232, interpolation=InterpolationMode.BILINEAR),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ToTensor already yields float32 in [0, 1]; kept for explicitness.
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    # Close the file handle promptly, and force 3 channels: grayscale,
    # palette or CMYK images would otherwise break the 3-channel Normalize.
    with Image.open(img_path) as raw:
        rgb = raw.convert("RGB")

    batch = res_transforms(rgb).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

    # No autograd graph is needed for prediction.
    with torch.no_grad():
        outputs = model(batch)
    _, prediction = torch.max(outputs, 1)
    print(prediction)
    return prediction

# Run the example prediction only when executed as a script, not on import.
# The sample image comes from the training folder of the "脏污" (stain) class.
if __name__ == "__main__":
    predict_img("./resources/aoi/train/脏污/12_1794_27.851_7.364_O_320_20230920_155110.jpg")

But in the Java code, acc = 0.8:

        List<String> herbNames = Arrays.asList("S型", "OK", "凸点", "残胶", "洗剂残留", "絮状物", "脏污", "透明块状脏污", "长瘤");

        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                //下面的transform根据自己的改
                .addTransform(new Resize(232))
                .addTransform(new CenterCrop(224, 224))
                .addTransform(new ToTensor())
                .addTransform(new Normalize(
                        new float[]{0.485f, 0.456f, 0.406f},
                        new float[]{0.229f, 0.224f, 0.225f}))
                //如果你的模型最后一层没有经过softmax就启用它
                .optApplySoftmax(true)
                //最终显示概率最高的5个
                .optTopK(3)
                .optSynset(herbNames)
                .build();

        Model model = Model.newInstance("resNext", Device.cpu());
        model.load(new ClassPathResource("model.pt").getInputStream());
        Predictor<Image, Classifications> predictor = model.newPredictor(translator);

        Image img = ImageFactory.getInstance().fromInputStream(new ClassPathResource("12_1794_27.851_7.364_O_320_20230920_155110.jpg").getInputStream());
        img.getWrappedImage();
        Classifications classifications = predictor.predict(img);
        System.out.println(classifications);

PyTorch version = 2.1.1, DJL version = 0.26.0-SNAPSHOT

yz4322gly avatar Dec 22 '23 03:12 yz4322gly