djl icon indicating copy to clipboard operation
djl copied to clipboard

[FATAL] extensions/tokenizers/rust/src/lib.rs crashes the process

Open zaobao opened this issue 9 months ago • 1 comments

DJL version 0.27.0

When a null-value sequence is passed to a text-encoder model, the Java process crashes.

Error output:

thread '<unnamed>' panicked at src/lib.rs:217:14:
Couldn't get java string!: NullPtr("get_string obj argument")
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
fatal runtime error: failed to initiate panic, error 5

Example:

public final class Encoder {

    private Encoder() {
    }

    private static final String uri;

    private static final Predictor<String[], float[][]> predictor;

    static {
        try {
            uri = Objects.requireNonNull(Encoder.class.getClassLoader().getResource("models/bge-m3.zip")).toURI().toString().replaceAll("jar:file.*!", "jar:");
            predictor = getPredictor();
        } catch (ModelNotFoundException | MalformedModelException | IOException | URISyntaxException e) {
            throw new RuntimeException(e);
        }
    }

    public static Predictor<String[], float[][]> getPredictor()
            throws ModelNotFoundException, MalformedModelException, IOException {
        if (predictor != null) {
            return predictor;
        }
        Criteria<String[], float[][]> criteria =
                Criteria.builder()
                        .setTypes(String[].class, float[][].class)
                        .optModelUrls(uri)
                        .optEngine("PyTorch")
                        .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
                        .optProgress(new ProgressBar())
                        .build();
        ZooModel<String[], float[][]> model = criteria.loadModel();
        Predictor<String[], float[][]> predictor = model.newPredictor();
        return predictor;
    }

    public static void main(String[] args) {
        String[] texts = new String[]{null}); // encode null string
        predictor.predict(texts);
    }
}

lib.rs

    // NOTE: `expect` panics whenever `get_string` returns `Err` — e.g. the
    // `NullPtr("get_string obj argument")` reported for a null Java string. Per the
    // reported error output, that panic ends in a fatal runtime error rather than an
    // error the Java caller can handle.
    let sequence: String = env
        .get_string(&input)
        .expect("Couldn't get java string!")
        .into();

jnienv.rs

    /// Builds a [`JavaStr`] from `obj` without first checking that it is a
    /// `java.lang.String`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller is presumably required to guarantee that `obj` refers to a
    /// `java.lang.String`; the checked `get_string` wrapper performs that test before
    /// delegating here — confirm against the crate documentation.
    ///
    /// # Errors
    ///
    /// A null `obj` is rejected by `non_null!` (reported as
    /// `NullPtr("get_string obj argument")`); otherwise the result of
    /// `JavaStr::from_env` is returned unchanged.
    pub unsafe fn get_string_unchecked<'other_local: 'obj_ref, 'obj_ref>(
        &self,
        obj: &'obj_ref JString<'other_local>,
    ) -> Result<JavaStr<'local, 'other_local, 'obj_ref>> {
        // Bail out with an error instead of handing a null reference to JNI.
        non_null!(obj, "get_string obj argument");
        JavaStr::from_env(self, obj)
    }

    /// Reads a Java string object as a [`JavaStr`], after verifying its class.
    ///
    /// The class of `obj` is compared against `java/lang/String` with
    /// `is_assignable_from` before the unchecked accessor is invoked.
    ///
    /// # Errors
    ///
    /// Returns `JniCall(JniError::InvalidArguments)` when the class check fails. Errors
    /// from `find_class`, `get_object_class`, `is_assignable_from` and
    /// `get_string_unchecked` are propagated unchanged.
    pub fn get_string<'other_local: 'obj_ref, 'obj_ref>(
        &mut self,
        obj: &'obj_ref JString<'other_local>,
    ) -> Result<JavaStr<'local, 'other_local, 'obj_ref>> {
        let expected_class = self.find_class("java/lang/String")?;
        let actual_class = self.get_object_class(obj)?;
        let is_string = self.is_assignable_from(expected_class, actual_class)?;

        if is_string {
            // SAFETY: the class check above confirmed that `obj` is a java.lang.String.
            unsafe { self.get_string_unchecked(obj) }
        } else {
            Err(JniCall(JniError::InvalidArguments))
        }
    }

zaobao avatar May 11 '24 11:05 zaobao

@zaobao

We should do better error handling in the Rust code. It should not just call `.expect()`.

frankfliu avatar May 11 '24 20:05 frankfliu