Quickest way to get a stacked tensor from an input batch
In Rust it seems like there are a few ways to go from a batch of text inputs to an encoded, stacked tensor.
If I read an input txt file into a Vec, I can use the encode_batch() function and take advantage of parallelization. However, this still returns Encoding values and not a tensor. To get the ids I then have to call get_ids() on every individual encoding, which means looping over everything again and storing each result before stacking at the end.
use std::fs::File;
use std::io::{BufRead, BufReader};

let path = "/path/to/input_file.txt";
let infile = File::open(path)?;
let reader = BufReader::new(infile);
let mut contents: Vec<EncodeInput> = vec![];
for line in reader.lines() {
    // Create an encoder input from each input line.
    contents.push(EncodeInput::Single(line?));
}
// Encode the whole batch, then loop over every encoding again to collect the ids.
let batch = tokenizer.encode_batch(contents, true)?;
let mut ids: Vec<&[u32]> = vec![];
for x in batch.iter() {
    ids.push(x.get_ids());
}
// Then the ids have to be converted to a tensor. This fails anyway, because of_slice does not accept u32, even for a single &[u32] sample.
let stacked_tensor = Tensor::of_slice(&ids);
Alternatively, I can read the input txt file line by line and do all of the above steps one at a time, without using encode_batch() at all.
Is there a better way that I am missing?
As a side question, the get_ids() method returns &[u32], which Tensor::of_slice does not accept when creating a tensor.
// Fails
let ids: &[u32] = &[2, 44, 55, 3, 0, 0, 0, 0, 0, 0];
let tensor = Tensor::of_slice(&ids);
// Works
let ids: &[i64] = &[2, 44, 55, 3, 0, 0, 0, 0, 0, 0];
let tensor = Tensor::of_slice(&ids);
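For reference, a minimal sketch of the conversion the failing case would need: widening each id from u32 to i64 before building the tensor. The helper name to_i64_ids is hypothetical and not part of tokenizers or tch; it uses only the standard library.

```rust
/// Widen token ids from `u32` (the element type `get_ids()` returns)
/// to `i64` (the element type that works in the second snippet above).
/// `i64::from` is lossless, so no id can be truncated.
fn to_i64_ids(ids: &[u32]) -> Vec<i64> {
    ids.iter().map(|&id| i64::from(id)).collect()
}
```

With tch in scope, the result could then presumably be passed along as Tensor::of_slice(&to_i64_ids(ids)), mirroring the i64 case that works.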
Indeed, we do not integrate with any downstream library at the moment and leave that step to you, since your use case might be completely different from others'. Do you have any suggestions for improving the API without making it specific to your use case?
Can you also point me to the documentation of the Tensor you mention? I'm really not sure what you are trying to use here.
Sure: https://docs.rs/tch/0.1.7/tch/struct.Tensor.html#method.of_slice
Thank you! As I expected, this method is generic over T, so the need to convert u32 applies to many other types as well. We might be able to provide a way to convert the Encoding down the road, but in the meantime you will have to convert everything manually and keep the results in your own Vec instances.
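As a rough sketch of what "convert everything manually" could look like for a whole batch, the ids can be copied into a single row-major i64 buffer together with its shape, so that only one tensor has to be created afterwards. The function name flatten_batch and the pad_id parameter are assumptions made for this illustration; neither comes from tokenizers or tch. Padding only matters if the encodings were not already padded to a common length.

```rust
/// Flatten a batch of id slices into one row-major `i64` buffer.
/// Returns the buffer and its `(rows, cols)` shape. Rows shorter than
/// the longest row are right-padded with `pad_id` (hypothetical parameter).
fn flatten_batch(rows: &[&[u32]], pad_id: u32) -> (Vec<i64>, (usize, usize)) {
    // The widest row determines the number of columns.
    let cols = rows.iter().map(|r| r.len()).max().unwrap_or(0);
    let mut flat = Vec::with_capacity(rows.len() * cols);
    for row in rows {
        // Widen each id losslessly from u32 to i64.
        flat.extend(row.iter().map(|&id| i64::from(id)));
        // Pad the remainder of the row, if any.
        flat.extend(std::iter::repeat(i64::from(pad_id)).take(cols - row.len()));
    }
    (flat, (rows.len(), cols))
}
```

The Vec<&[u32]> collected from get_ids() can be passed in directly as &ids. The flat buffer can then be handed to Tensor::of_slice once and reshaped to (rows, cols); the exact reshape call depends on the tch version, so it is left out here.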