djl
[FATAL] extensions/tokenizers/rust/src/lib.rs crashes the process
DJL version 0.27.0
When a `null` sequence is passed to a text-encoder model, the entire Java process crashes: the Rust tokenizer panics, and the JVM aborts instead of receiving a Java exception.
Error output:
thread '<unnamed>' panicked at src/lib.rs:217:14:
Couldn't get java string!: NullPtr("get_string obj argument")
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
fatal runtime error: failed to initiate panic, error 5
Example:
public final class Encoder {
private Encoder() {
}
private static final String uri;
private static final Predictor<String[], float[][]> predictor;
static {
try {
uri = Objects.requireNonNull(Encoder.class.getClassLoader().getResource("models/bge-m3.zip")).toURI().toString().replaceAll("jar:file.*!", "jar:");
predictor = getPredictor();
} catch (ModelNotFoundException | MalformedModelException | IOException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
public static Predictor<String[], float[][]> getPredictor()
throws ModelNotFoundException, MalformedModelException, IOException {
if (predictor != null) {
return predictor;
}
Criteria<String[], float[][]> criteria =
Criteria.builder()
.setTypes(String[].class, float[][].class)
.optModelUrls(uri)
.optEngine("PyTorch")
.optTranslatorFactory(new TextEmbeddingTranslatorFactory())
.optProgress(new ProgressBar())
.build();
ZooModel<String[], float[][]> model = criteria.loadModel();
Predictor<String[], float[][]> predictor = model.newPredictor();
return predictor;
}
public static void main(String[] args) {
String[] texts = new String[]{null}); // encode null string
predictor.predict(texts);
}
}
lib.rs (extensions/tokenizers/rust/src/lib.rs — the call that panics):
// NOTE(review): `.expect()` turns every `get_string` failure into a Rust panic inside a JNI call.
// Per the reported output, a `null` Java string makes `get_string` return
// `Err(NullPtr("get_string obj argument"))` here. The resulting panic then aborts the whole JVM
// ("fatal runtime error: failed to initiate panic") instead of reaching Java as an exception.
// TODO: report the error to Java (e.g. throw an exception through `env`) and return early.
// The enclosing function's return type is not visible here, so the early-return value must be
// chosen there.
let sequence: String = env
.get_string(&input)
.expect("Couldn't get java string!")
.into();
jnienv.rs (from the `jni` crate — `get_string` rejects a null reference with a `NullPtr` error):
/// Borrows the contents of `obj` as a [`JavaStr`] without checking that it is a
/// `java.lang.String`.
///
/// # Errors
///
/// Returns `NullPtr("get_string obj argument")` (raised by `non_null!`) when `obj` is a null
/// reference. This is the error shown in the reported panic. Otherwise it returns whatever
/// `JavaStr::from_env` returns.
///
/// # Safety
///
/// `obj` must refer to a `java.lang.String`. Unlike `get_string`, no class check is made here.
pub unsafe fn get_string_unchecked<'other_local: 'obj_ref, 'obj_ref>(
&self,
obj: &'obj_ref JString<'other_local>,
) -> Result<JavaStr<'local, 'other_local, 'obj_ref>> {
// Reject a null reference before it is handed to `JavaStr::from_env`.
non_null!(obj, "get_string obj argument");
JavaStr::from_env(self, obj)
}
/// Borrows the contents of a Java string after confirming that the referenced object's class is
/// assignable to `java.lang.String`.
///
/// # Errors
///
/// - Propagates any failure from `find_class`, `get_object_class`, or `is_assignable_from`.
/// - Returns `JniCall(JniError::InvalidArguments)` when the object is not a `java.lang.String`.
/// - Otherwise returns the result of `get_string_unchecked`, which reports a null `obj` as
///   `NullPtr("get_string obj argument")`.
pub fn get_string<'other_local: 'obj_ref, 'obj_ref>(
    &mut self,
    obj: &'obj_ref JString<'other_local>,
) -> Result<JavaStr<'local, 'other_local, 'obj_ref>> {
    let java_lang_string = self.find_class("java/lang/String")?;
    let actual_class = self.get_object_class(obj)?;
    let is_string = self.is_assignable_from(java_lang_string, actual_class)?;

    if is_string {
        // SAFETY: the assignability check above established that `obj` is a java.lang.String.
        unsafe { self.get_string_unchecked(obj) }
    } else {
        Err(JniCall(JniError::InvalidArguments))
    }
}
@zaobao
We should do better error handling in the Rust code; it should not just call `.expect()`.