djl icon indicating copy to clipboard operation
djl copied to clipboard

isTuple() INTERNAL ASSERT FAILED at "..\\..\\aten\\src\\ATen/core/ivalue_inl.h":1101, please report a bug to PyTorch. Expected Tuple but got String

Status: Open. Nanran1220 opened this issue 7 months ago • 3 comments

ai.djl.translate.TranslateException: ai.djl.engine.EngineException: isTuple() INTERNAL ASSERT FAILED at "..\..\aten\src\ATen/core/ivalue_inl.h":1101, please report a bug to PyTorch. Expected Tuple but got String at ai.djl.inference.Predictor.batchPredict(Predictor.java:186) ~[api-0.15.0.jar:na] at ai.djl.inference.Predictor.predict(Predictor.java:123) ~[api-0.15.0.jar:na] at com.ccf.image.service.FeatureExtractor.extract(FeatureExtractor.java:69) ~[classes/:na] at com.ccf.image.controller.ImageController.extractVector(ImageController.java:31) ~[classes/:na] at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method) ~[na:na] at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) ~[na:na] at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) ~[na:na] at java.base/java.lang.reflect.Method.invoke(Method.java:566) ~[na:na] at org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:205) ~[spring-web-5.3.27.jar:5.3.27] at org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:150) ~[spring-web-5.3.27.jar:5.3.27] at org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:117) ~[spring-webmvc-5.3.27.jar:5.3.27] at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:895) ~[spring-webmvc-5.3.27.jar:5.3.27] at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:808) ~[spring-webmvc-5.3.27.jar:5.3.27] at org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87) ~[spring-webmvc-5.3.27.jar:5.3.27] at 
org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:1072) ~[spring-webmvc-5.3.27.jar:5.3.27] at org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:965) ~[spring-webmvc-5.3.27.jar:5.3.27] at org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:1006) ~[spring-webmvc-5.3.27.jar:5.3.27] at org.springframework.web.servlet.FrameworkServlet.doPost(FrameworkServlet.java:909) ~[spring-webmvc-5.3.27.jar:5.3.27] at javax.servlet.http.HttpServlet.service(HttpServlet.java:555) ~[tomcat-embed-core-9.0.75.jar:4.0.FR] at org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:883) ~[spring-webmvc-5.3.27.jar:5.3.27] at javax.servlet.http.HttpServlet.service(HttpServlet.java:623) ~[tomcat-embed-core-9.0.75.jar:4.0.FR] at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:209) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:153) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:51) ~[tomcat-embed-websocket-9.0.75.jar:9.0.75] at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:178) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:153) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:100) ~[spring-web-5.3.27.jar:5.3.27] at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:117) ~[spring-web-5.3.27.jar:5.3.27] at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:178) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:153) 
~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.springframework.web.filter.FormContentFilter.doFilterInternal(FormContentFilter.java:93) ~[spring-web-5.3.27.jar:5.3.27] at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:117) ~[spring-web-5.3.27.jar:5.3.27] at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:178) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:153) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:201) ~[spring-web-5.3.27.jar:5.3.27] at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:117) ~[spring-web-5.3.27.jar:5.3.27] at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:178) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:153) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:167) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:90) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:481) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:130) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:93) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:74) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:343) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at 
org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:390) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:63) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:926) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1791) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:52) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.tomcat.util.threads.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1191) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.tomcat.util.threads.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:659) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:61) ~[tomcat-embed-core-9.0.75.jar:9.0.75] at java.base/java.lang.Thread.run(Thread.java:834) ~[na:na] Caused by: ai.djl.engine.EngineException: isTuple() INTERNAL ASSERT FAILED at "..\..\aten\src\ATen/core/ivalue_inl.h":1101, please report a bug to PyTorch. Expected Tuple but got String at ai.djl.pytorch.jni.PyTorchLibrary.moduleForward(Native Method) ~[pytorch-engine-0.15.0.jar:na] at ai.djl.pytorch.jni.IValueUtils.forward(IValueUtils.java:46) ~[pytorch-engine-0.15.0.jar:na] at ai.djl.pytorch.engine.PtSymbolBlock.forwardInternal(PtSymbolBlock.java:126) ~[pytorch-engine-0.15.0.jar:na] at ai.djl.nn.AbstractBlock.forward(AbstractBlock.java:126) ~[api-0.15.0.jar:na] at ai.djl.nn.Block.forward(Block.java:122) ~[api-0.15.0.jar:na] at ai.djl.inference.Predictor.predictInternal(Predictor.java:137) ~[api-0.15.0.jar:na] at ai.djl.inference.Predictor.batchPredict(Predictor.java:177) ~[api-0.15.0.jar:na]

@Service @Slf4j public class FeatureExtractor {

private static final String MODEL_NAME = "resnet50";

public float[] extract(MultipartFile imageFile) throws Exception {

    Criteria<Image, float[]> criteria1 = Criteria.builder()

// .optModelName(MODEL_NAME) .setTypes(Image.class, float[].class) // 使用 TorchVision 提供的 ResNet50 预训练模型路径 .optModelPath(Paths.get("F:\PycharmProjects\yolov5\resnet50.pt")) // .optModelUrls("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/pytorch/resnet50/0.0.1/") .optTranslator(new MyTranslator()) // 必须启用自定义 Translator .optEngine("PyTorch") // 指定引擎 .optProgress(new ProgressBar()) .build();

    Criteria<Image, Classifications> criteria2 = Criteria.builder()

// .optModelName(MODEL_NAME) .setTypes(Image.class, Classifications.class) .optModelPath(Paths.get("F:\PycharmProjects\yolov5\resnet50_features.pt")) .optEngine("PyTorch") // 指定引擎 .optProgress(new ProgressBar()) .build();

    Criteria<Image, Classifications> criteria = Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            .optEngine("PyTorch")
            .optArtifactId("resnet50")
            .optFilter("backbone", "true")  // 提取特征而非分类
            .build();

    try (ZooModel<Image, float[]> model = criteria1.loadModel();
         Predictor<Image, float[]> predictor = model.newPredictor()) {

        // 1. 预处理图像
        Image image = preprocessImage(imageFile);

        // 2. 创建推理器

// Predictor<Image, float[]> predictor = model.newPredictor();

        // 3. 执行推理
        return predictor.predict(image);
    }
}

private Image preprocessImage(MultipartFile file) throws IOException {
    try (InputStream is = file.getInputStream()) {
        return ImageFactory.getInstance().fromInputStream(is);
    }
}

public class MyTranslator implements Translator<Image, float[]> {

    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);

        // 正确的预处理顺序:
        array = array.transpose(2, 0, 1);  // HWC -> CHW
        array = array.div(255f);  // 归一化到 [0, 1]
        array = NDImageUtils.normalize(array,
                new float[]{0.485f, 0.456f, 0.406f},
                new float[]{0.229f, 0.224f, 0.225f}  // ImageNet 均值标准差
        );
        array = array.expandDims(0);  // 添加 batch 维度

        return new NDList(array);
    }

    @Override
    public float[] processOutput(TranslatorContext ctx, NDList list) {

        // 遍历所有输出项,找到特征张量
        for (NDArray arr : list) {
            System.out.println("Output item shape: " + arr.getShape());
        }

        // 假设特征张量是第一个元素且是四维的 [1, 2048, 1, 1]
        NDArray output = list.get(0);
        output = output.squeeze();  // 去除冗余维度 -> [2048]
        return output.toFloatArray();

        // 提取全连接层前的特征向量(ResNet50输出为2048维)

// NDArray output = list.singletonOrThrow(); // return output.squeeze(0).toFloatArray(); // 去除batch维度

// NDArray output = list.singletonOrThrow(); // output = output.squeeze(new int[] {0, 2, 3}); // 去除冗余维度 // return output.toFloatArray(); }

    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }
}

}

Nanran1220 avatar Apr 18 '25 02:04 Nanran1220