accelerate
How can I gather non-tensor objects (e.g. strings) for evaluation?
I'm a new user of accelerate. I'm wondering how I can gather non-tensor objects (e.g. strings) when evaluating.
For example:
generated_tokens, labels, choices, labs, dts = accelerator.gather((generated_tokens, labels, choices, labs, dts))
In this case, generated_tokens and labels are tensors, and gather works correctly.
But labs and dts are strings, and gather doesn't work: I get back an empty list.
How should I deal with that?
Hello @TZWwww, you can have your dataset also return the index of each sample. Gather those indices and then retrieve the related textual features for that list of indices.
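A minimal sketch of that bookkeeping (pure Python, with the gather result stubbed in; in a real loop each index would be returned as a tensor and passed through `accelerator.gather`):

```python
# Non-tensor features live in a plain list that every process has a copy of;
# only the integer indices travel through the gather.
texts = ["the cat", "a dog", "some bird"]

# Each process evaluates its shard and records the sample indices it saw.
indices_rank0 = [0, 2]
indices_rank1 = [1]

# Stand-in for the result of accelerator.gather on the index tensors:
gathered = indices_rank0 + indices_rank1

# Map the gathered indices back to the strings on the main process.
recovered = [texts[i] for i in gathered]
print(recovered)  # ['the cat', 'some bird', 'a dog']
```

The strings themselves never cross processes; only indices do, which is why this sidesteps the non-tensor limitation of gather.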
@TZWwww that is the main recommendation because torch does not currently support gathering non-tensor objects: https://github.com/pytorch/pytorch/issues/62466
An interim option I'm also considering for a utility is to encode your strings as character ordinals (via `ord`) inside your dataset. This turns them into tensors, and you can decode them back afterwards. E.g.:
import torch

def _num_to_str(nums):
    "Decodes a nested list of character codes (from `_str_to_num`) back into a string"
    s = ''
    for batch in nums:
        for char in batch:
            if char != 0:  # skip the zero padding added during encoding
                s += chr(char)
    return s

def _str_to_num(string, batch_size=64):
    "Encodes `string` to a decodeable list of character codes and breaks it up by `batch_size`"
    batch, inner_batch = [], []
    for i, char in enumerate(string):
        inner_batch.append(ord(char))
        if (len(inner_batch) == batch_size) or (i == len(string) - 1):
            # Zero-pad the last chunk so every row has the same length;
            # torch.tensor fails on ragged nested lists
            inner_batch += [0] * (batch_size - len(inner_batch))
            batch.append(inner_batch)
            inner_batch = []
    return batch

def str_to_tensor(string, batch_size=64) -> torch.Tensor:
    """
    Encodes `string` to a tensor of shape [n, batch_size] where
    `batch_size` is the number of characters per row and `n` is
    ceil(len(string) / batch_size)
    """
    return torch.tensor(_str_to_num(string, batch_size), dtype=torch.long)

def tensor_to_str(x: torch.Tensor) -> str:
    """
    Decodes `x` to a string. `x` must have been encoded by
    `str_to_tensor`
    """
    return _num_to_str(x.tolist())
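As a sanity check, here is a standalone round trip through this scheme (the helpers are re-sketched compactly so the snippet runs on its own, with zero-padding on the final row so `torch.tensor` doesn't choke on ragged lists):

```python
import torch

# Compact re-sketch of the encode/decode helpers above
def str_to_tensor(s, batch_size=64):
    nums = [ord(c) for c in s]
    rows = [nums[i:i + batch_size] for i in range(0, len(nums), batch_size)]
    rows[-1] += [0] * (batch_size - len(rows[-1]))  # pad the last row
    return torch.tensor(rows, dtype=torch.long)

def tensor_to_str(x):
    # Drop the zero padding while decoding
    return ''.join(chr(c) for c in x.flatten().tolist() if c != 0)

t = str_to_tensor("hello from rank 0")
print(t.shape)            # torch.Size([1, 64])
print(tensor_to_str(t))   # hello from rank 0
```

In a distributed loop you would encode each process's strings this way, gather the resulting long tensors with `accelerator.gather` (padding across processes first if their shapes differ), and decode back to strings on the main process.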
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.