rust-bert
rust-bert copied to clipboard
RemoteResource doesn't allow loading safetensors models
RemoteResource resource provider doesn't preserve file name or extension
let cached_path = CACHE
.cached_path_with_options(&self.url, &Options::default().subdir(&self.cache_subdir))?;
Ok(cached_path)
but Tch-rs requires model path to have safetensor extension to load model file in Safetensors format
fn named_tensors<T: AsRef<std::path::Path>>(
&self,
path: T,
) -> Result<HashMap<String, Tensor>, TchError> {
let named_tensors = match path.as_ref().extension().and_then(|x| x.to_str()) {
Some("bin") | Some("pt") => Tensor::loadz_multi_with_device(&path, self.device),
Some("safetensors") => Tensor::read_safetensors(path),
Some(_) | None => Tensor::load_multi_with_device(&path, self.device),
};
Ok(named_tensors?.into_iter().collect())
}