rust-xgboost
help on safe implementation of XGBoosterGetModelRaw?
hey -
I was trying to implement the other half of loading/saving a model from a buffer, using XGBoosterGetModelRaw, but I'm stuck and thought you might be able to help.
This is where I am (in the Booster impl):
/// Returns a `Vec<u8>` with model weights.
pub fn to_vec(&self) -> XGBResult<Vec<u8>> {
    debug!("Writing Booster to_vec");
    let mut out_len = 0;
    let mut out_dptr = ptr::null_mut();
    xgb_call!(xgboost_sys::XGBoosterGetModelRaw(self.handle, &mut out_len, out_dptr))?;
    // let bytes: &[u8] = unsafe {
    //     let length: u64 = *(out_len as *const _);
    //     std::slice::from_raw_parts(out_dptr as *const _, length as usize)
    // };
    // let mut out: Vec<u8> = vec![0u8; bytes.len()];
    // out[..].copy_from_slice(bytes);
    // Ok(out)
}
The section is commented out because I can't get past the call to the XGBoost function itself.
I have tried a wide variety of pointer types and casts when calling XGBoosterGetModelRaw, but I get a SIGSEGV every time.
The XGBoost C API declares it as:
/*!
 * \brief save model into binary raw bytes, return header of the array
 * user must copy the result out, before next xgboost call
 * \param handle handle
 * \param out_len the argument to hold the output length
 * \param out_dptr the argument to hold the output data pointer
 * \return 0 when success, -1 when failure happens
 */
XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
                                 bst_ulong *out_len,
                                 const char **out_dptr);
Any ideas? Thanks!
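To make the header's contract concrete: the function writes a length into `*out_len` and a buffer address into `*out_dptr`, so the caller must pass the addresses of two locals it owns. The sketch below uses a hypothetical stand-in, `fake_get_model_raw`, defined in Rust with the same out-parameter shape (assuming `bst_ulong` is `u64` and leaving out the handle), so it runs without linking XGBoost. It is not the xgboost-sys binding.

```rust
use std::os::raw::{c_char, c_int};
use std::ptr;

// Hypothetical stand-in for XGBoosterGetModelRaw: same out-parameter shape
// (bst_ulong assumed to be u64), but it hands back a static buffer instead of
// serializing a real booster.
static FAKE_MODEL: [u8; 5] = [b'b', b'i', b'n', 0, 7];

unsafe extern "C" fn fake_get_model_raw(out_len: *mut u64, out_dptr: *mut *const c_char) -> c_int {
    unsafe {
        // The callee writes THROUGH both pointers. If either one is null this
        // is undefined behavior (in practice a SIGSEGV), which is what happens
        // when `ptr::null_mut()` is passed directly as `out_dptr`.
        *out_len = FAKE_MODEL.len() as u64;
        *out_dptr = FAKE_MODEL.as_ptr() as *const c_char;
    }
    0
}

fn get_raw() -> (u64, *const c_char) {
    let mut out_len: u64 = 0;
    // A local `*const c_char` for the callee to fill in...
    let mut out_dptr: *const c_char = ptr::null();
    // ...whose address we pass: `&mut out_dptr` coerces to
    // `*mut *const c_char`, the Rust spelling of C's `const char **`.
    let ret = unsafe { fake_get_model_raw(&mut out_len, &mut out_dptr) };
    assert_eq!(ret, 0);
    (out_len, out_dptr)
}

fn main() {
    let (len, p) = get_raw();
    println!("len = {len}, pointer is null: {}", p.is_null());
}
```

Running it prints `len = 5, pointer is null: false`: after the call, the local pointer holds the buffer address the callee wrote into it.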
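I think it might be because out_dptr is a double pointer in the C API (const char **). XGBoost writes the address of its buffer into *out_dptr, so passing a null pointer directly makes it write through null, hence the SIGSEGV. Declaring out_dptr with ptr::null() (a *const c_char, matching the const) instead of null_mut() and passing a mutable reference to it (&mut out_dptr) to the C API should work. Something like the following (though I haven't tested it):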
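/// Returns a `Vec<u8>` with model weights.
pub fn to_vec(&self) -> XGBResult<Vec<u8>> {
    debug!("Writing Booster to_vec");
    let mut out_len = 0;
    // XGBoost writes the buffer address into this pointer, so pass `&mut out_dptr`.
    let mut out_dptr: *const std::os::raw::c_char = std::ptr::null();
    xgb_call!(xgboost_sys::XGBoosterGetModelRaw(self.handle, &mut out_len, &mut out_dptr))?;
    // `slice::from_raw_parts` requires a non-null pointer, even for length 0.
    if out_dptr.is_null() || out_len == 0 {
        return Ok(Vec::new());
    }
    // The buffer is raw binary (it can contain zero bytes), so copy exactly
    // `out_len` bytes rather than reading it as C strings. Copy it right away:
    // it is only valid until the next XGBoost call.
    let bytes = unsafe { std::slice::from_raw_parts(out_dptr as *const u8, out_len as usize) };
    Ok(bytes.to_vec())
}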
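The header describes `out_dptr` as pointing to raw binary bytes whose length is `*out_len`, not to a NUL-terminated string, so the data should be copied with `slice::from_raw_parts` rather than read through `CStr`, which stops at the first zero byte. The sketch below uses a hypothetical stand-in, `fake_get_model_raw`, with a made-up 6-byte buffer containing zero bytes (assuming `bst_ulong` is `u64`). It compares the two approaches and copies the data into an owned `Vec` immediately, since the header says the result must be copied out before the next XGBoost call.

```rust
use std::ffi::CStr;
use std::os::raw::{c_char, c_int};
use std::{ptr, slice};

// Hypothetical model bytes with embedded zeros, as binary model data may contain.
static FAKE_MODEL: [u8; 6] = [1, 2, 0, 3, 4, 0];

// Hypothetical stand-in with the same out-parameter shape as XGBoosterGetModelRaw.
unsafe extern "C" fn fake_get_model_raw(out_len: *mut u64, out_dptr: *mut *const c_char) -> c_int {
    unsafe {
        *out_len = FAKE_MODEL.len() as u64;
        *out_dptr = FAKE_MODEL.as_ptr() as *const c_char;
    }
    0
}

/// Copies the buffer described by (`out_dptr`, `out_len`) into an owned Vec.
fn to_vec_sketch() -> Result<Vec<u8>, c_int> {
    let mut out_len: u64 = 0;
    let mut out_dptr: *const c_char = ptr::null();
    let ret = unsafe { fake_get_model_raw(&mut out_len, &mut out_dptr) };
    if ret != 0 {
        return Err(ret);
    }
    // `from_raw_parts` needs a non-null pointer even when the length is 0.
    if out_dptr.is_null() || out_len == 0 {
        return Ok(Vec::new());
    }
    // Treat the buffer as exactly `out_len` bytes, zeros included.
    let bytes = unsafe { slice::from_raw_parts(out_dptr as *const u8, out_len as usize) };
    Ok(bytes.to_vec())
}

fn main() {
    let v = to_vec_sketch().unwrap();
    // A CStr view of the same buffer stops at the first zero byte.
    let c_view = unsafe { CStr::from_ptr(FAKE_MODEL.as_ptr() as *const c_char) };
    println!("copied {} bytes, CStr saw {}", v.len(), c_view.to_bytes().len());
}
```

This prints `copied 6 bytes, CStr saw 2`: a `CStr`-based copy would silently truncate the model at its first zero byte.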