
RemoteResource doesn't allow loading safetensors models

Open · zaytsev opened this issue · 0 comments

The `RemoteResource` resource provider doesn't preserve the original file name or extension. It returns whatever path the cache assigns:

        // The returned cache path does not keep the URL's file name or extension.
        let cached_path = CACHE
            .cached_path_with_options(&self.url, &Options::default().subdir(&self.cache_subdir))?;
        Ok(cached_path)
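Since the cached path does not carry the URL's extension, one possible caller-side workaround is to recover the extension from the URL and expose the cached file under a name that ends with it, so that loaders dispatching on the extension pick the right format. This is a minimal std-only sketch; `url_extension` and `path_with_url_extension` are hypothetical helpers, not part of rust-bert, and `cached` is assumed to be the local path returned by the resource provider.

```rust
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Hypothetical helper: extension of the last path segment of `url`,
/// ignoring any query string or fragment.
fn url_extension(url: &str) -> Option<&str> {
    let path = url.split(|c: char| c == '?' || c == '#').next()?;
    let file_name = path.rsplit('/').next()?;
    Path::new(file_name).extension()?.to_str()
}

/// Hypothetical helper: makes the cached file reachable under a name that
/// ends with the extension taken from `url`. Returns `cached` unchanged if it
/// already has that extension or if the URL has none.
fn path_with_url_extension(cached: &Path, url: &str) -> io::Result<PathBuf> {
    let ext = match url_extension(url) {
        Some(ext) => ext,
        None => return Ok(cached.to_path_buf()),
    };
    if cached.extension().and_then(|e| e.to_str()) == Some(ext) {
        return Ok(cached.to_path_buf());
    }
    // Append (rather than replace) the extension so distinct cache entries
    // cannot collide, e.g. `abc123.def456` -> `abc123.def456.safetensors`.
    let mut name = cached
        .file_name()
        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "cached path has no file name"))?
        .to_os_string();
    name.push(".");
    name.push(ext);
    let target = cached.with_file_name(name);
    if !target.exists() {
        // A hard link avoids duplicating a large model file; fall back to a copy.
        if fs::hard_link(cached, &target).is_err() {
            fs::copy(cached, &target)?;
        }
    }
    Ok(target)
}
```

Appending the extension instead of calling `with_extension` keeps the full cache file name intact, so two cached versions of the same file never map to the same linked path.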

However, tch-rs only loads a model file in the Safetensors format if its path has a `.safetensors` extension:

    fn named_tensors<T: AsRef<std::path::Path>>(
        &self,
        path: T,
    ) -> Result<HashMap<String, Tensor>, TchError> {
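        // The loader is chosen purely from the file extension.
        let named_tensors = match path.as_ref().extension().and_then(|x| x.to_str()) {
            Some("bin") | Some("pt") => Tensor::loadz_multi_with_device(&path, self.device),
            Some("safetensors") => Tensor::read_safetensors(path),
            // Any other or missing extension, including a cache file whose
            // name does not end in `.safetensors`, falls through to here.
            Some(_) | None => Tensor::load_multi_with_device(&path, self.device),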
        };
        Ok(named_tensors?.into_iter().collect())
    }
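The interaction between the two snippets can be isolated in a small std-only sketch that mirrors the extension match above and reports which loader would be chosen. The `Loader` enum and `loader_for` function are hypothetical names used only for illustration, and `abc123.def456` is a made-up stand-in for a cache file name that lacks the `.safetensors` suffix.

```rust
use std::path::Path;

/// Hypothetical enum naming the tch-rs loader the dispatch above would pick.
#[derive(Debug, PartialEq, Eq)]
enum Loader {
    /// `Tensor::loadz_multi_with_device` (`.bin`, `.pt`)
    Loadz,
    /// `Tensor::read_safetensors` (`.safetensors`)
    Safetensors,
    /// `Tensor::load_multi_with_device` (any other or missing extension)
    Load,
}

/// Same match as in `named_tensors`, returning the chosen loader instead of calling it.
fn loader_for(path: &Path) -> Loader {
    match path.extension().and_then(|x| x.to_str()) {
        Some("bin") | Some("pt") => Loader::Loadz,
        Some("safetensors") => Loader::Safetensors,
        Some(_) | None => Loader::Load,
    }
}
```

For example, `loader_for(Path::new("model.safetensors"))` yields `Loader::Safetensors`, while `loader_for(Path::new("abc123.def456"))` yields `Loader::Load`: a Safetensors file downloaded through `RemoteResource` is routed to `load_multi_with_device` rather than `read_safetensors`.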

zaytsev, Feb 29 '24 13:02