HuggingFace datasets wrapper
We should be able to use HuggingFace datasets directly in RETURNN.
I guess the most canonical way would be to write a RETURNN Dataset for this, maybe derived from CachedDataset2.
A separate, more direct PyTorch dataset wrapper might also make sense. Or actually I think that is not needed, as HuggingFace already supports this directly?
I guess the most canonical way would be to write a RETURNN Dataset for this. Maybe derived from CachedDataset2.
I already implemented this as a custom dataset some time ago (see implementation below).
Things to discuss are maybe how to handle tokenisation and other preprocessing steps.
In the current implementation, there is either the option to define a map function (see https://huggingface.co/docs/datasets/about_map_batch) in the config file, or to provide a preprocessed dataset stored with save_to_disk.
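As a rough illustration, a config could look like the following (the dataset, keys, and map function here are just assumptions for illustration, not part of the implementation below; the class would also need to be importable/registered as a RETURNN dataset):

def my_map_func(dataset):
    # arbitrary preprocessing via datasets.Dataset.map (tokenisation, label conversion, ...)
    return dataset.map(lambda example: example)

train = {
    "class": "HuggingfaceDataset",
    # either kwargs for datasets.load_dataset() ...
    "dataset_opts": {"path": "conll2003", "split": "train"},
    # ... or a path to a dataset stored with save_to_disk(), loaded via load_from_disk():
    # "dataset_opts": "/path/to/preprocessed_dataset",
    "map_func": "my_map_func",  # name of a function in the config, resolved via typed_value()
    "data_key": "ner_tags",
    "seq_tag_key": "id",
    "features": ["ner_tags"],
}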
num_outputs and the data dtype are currently extracted from the dataset's features attribute (which may not always be available).
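For illustration, for a tagging dataset the features attribute looks roughly like this (label set shortened here):

import datasets

features = datasets.Features({
    "id": datasets.Value("string"),
    "tokens": datasets.Sequence(datasets.Value("string")),
    "ner_tags": datasets.Sequence(datasets.ClassLabel(names=["O", "B-PER", "I-PER"])),
})
# "ner_tags": one Sequence level -> one spatial dim, ClassLabel -> dtype "int64",
# num_classes = 3, labels = ["O", "B-PER", "I-PER"]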
Another point to discuss is how to handle caching for large datasets (currently the default caching mechanism of HF datasets is used, which does not fit very well with our setups).
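Not a solution to that, but at least the cache location can be redirected, either via the HF_DATASETS_CACHE environment variable or per dataset, e.g.:

import datasets

ds = datasets.load_dataset("conll2003", split="train", cache_dir="/local/scratch/hf_datasets_cache")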
Code
import numpy

from returnn.datasets.basic import DatasetSeq
from returnn.datasets.cached2 import CachedDataset2
from returnn.util.basic import OptionalNotImplementedError


class HuggingfaceDataset(CachedDataset2):
    @classmethod
    def kwargs_update_from_config(cls, config, kwargs):
        super().kwargs_update_from_config(config, kwargs)
        if 'map_func' in kwargs:
            if isinstance(kwargs['map_func'], str):
                kwargs['map_func'] = config.typed_value(kwargs['map_func'])

    def __init__(self, dataset_opts, map_func=None, map_func_args=None, data_key='data', seq_tag_key='id', features=None,
                 **kwargs):
        super(HuggingfaceDataset, self).__init__(**kwargs)
        self._seq_order = None
        self.dataset_opts = dataset_opts
        if isinstance(map_func, str):
            from returnn.config import get_global_config
            config = get_global_config(raise_exception=False)
            map_func = config.typed_value(map_func)
        if map_func_args is not None:
            # map_func is treated as a factory here: calling it with map_func_args
            # returns the actual map function.
            map_func = map_func(**map_func_args)
        self.map_func = map_func
        self.dataset = None
        self.data_key = data_key
        self.seq_tag_key = seq_tag_key
        self.feature_keys = features
        self.data_dtype = {}

    def initialize(self):
        # Load the dataset
        import datasets
        if isinstance(self.dataset_opts, dict):
            self.dataset = datasets.load_dataset(**self.dataset_opts)
        else:
            self.dataset = datasets.load_from_disk(self.dataset_opts)
        assert isinstance(self.dataset, datasets.Dataset)

        if self.map_func is not None:
            self.dataset = self.map_func(self.dataset)

        if self.feature_keys is None:
            self.feature_keys = list(self.dataset.features.keys())
            if self.seq_tag_key is not None and self.seq_tag_key in self.feature_keys:
                self.feature_keys.remove(self.seq_tag_key)
            else:
                assert False, "Dataset does not have a seq_tag"

        self.dataset.set_format('numpy')

        if self.seq_tag_key is not None:
            assert self.seq_tag_key in self.dataset.column_names

        self.labels = {}
        self.num_outputs = {}
        for key in self.feature_keys:
            feature = self.dataset.features[key]
            dtype = None
            num_classes = None
            spatial_dims = 0
            # Unwrap (possibly nested) Sequence features to get the inner feature type.
            while type(feature) is datasets.features.Sequence:
                spatial_dims += 1
                if feature.length != -1:
                    num_classes = feature.length
                feature = feature.feature
            if type(feature) is datasets.features.ClassLabel:
                self.labels[key] = feature.names
                dtype = feature.dtype
                num_classes = feature.num_classes
            elif type(feature) is datasets.features.Value:
                dtype = feature.dtype
            elif isinstance(feature, (datasets.features.Array2D, datasets.features.Array3D, datasets.features.Array4D)):
                dtype = feature.dtype
                num_classes = feature.shape[-1]
                spatial_dims += len(feature.shape)
            else:
                assert False, f"Unsupported feature type {type(feature)}"
            len_shape = spatial_dims
            self.num_outputs[key] = [num_classes, len_shape]
            self.data_dtype[key] = dtype

        super().initialize()

    def get_data_dim(self, key):
        if key in self.num_outputs:
            return self.num_outputs[key][0]
        return super().get_data_dim(key)

    def get_data_dtype(self, key):
        return self.data_dtype[key]

    def _get_seq_len(self, seq_idx):
        return len(self.dataset[seq_idx][self.data_key])

    @property
    def num_seqs(self):
        assert self._seq_order is not None, "num_seqs is only known after calling init_seq_order()"
        return len(self._seq_order)

    def get_tag(self, sorted_seq_idx):
        return self.dataset[int(self.get_corpus_seq_idx(sorted_seq_idx))][self.seq_tag_key]

    def get_all_tags(self):
        return list(self.dataset[self.seq_tag_key])

    def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
        """
        :param int|None epoch:
        :param list[str]|None seq_list: List of sequence tags, to set a predefined order.
        :param list[int]|None seq_order: List of corpus sequence indices, to set a predefined order.
        :rtype: bool
        :returns whether the order changed (True is always safe to return)
        """
        super(HuggingfaceDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
        if seq_order:
            self._seq_order = seq_order
            # TODO can we return False?
            return True
        if seq_list:
            all_tags = self.get_all_tags()
            self._seq_order = [all_tags.index(tag) for tag in seq_list]
            # TODO can we return False?
            return True
        try:
            self._seq_order = self.get_seq_order_for_epoch(
                epoch=epoch, num_seqs=self.dataset.num_rows, get_seq_len=self._get_seq_len)
        except OptionalNotImplementedError:
            # only support seq_ordering that need no length here
            assert self.seq_ordering in ["default", "reverse", "random"]
            self._seq_order = self.get_seq_order_for_epoch(
                epoch=epoch, num_seqs=self.dataset.num_rows, get_seq_len=None)
        return True

    def _collect_single_seq(self, seq_idx):
        """
        :param int seq_idx: sorted seq idx
        :return:
        """
        corpus_seq_idx = self.get_corpus_seq_idx(seq_idx)

        def ensure_numpy(x):
            if not isinstance(x, numpy.ndarray):
                return numpy.array(x)
            return x

        dataset_item = self.dataset[int(corpus_seq_idx)]
        features = {f: ensure_numpy(dataset_item[f]) for f in self.feature_keys}

        return DatasetSeq(
            seq_idx,
            features=features,
            targets=None,
            seq_tag=dataset_item[self.seq_tag_key]
        )

    def get_current_seq_order(self):
        """
        :rtype: list[int]
        """
        assert self._seq_order is not None
        return self._seq_order

    def get_corpus_seq_idx(self, sorted_seq_idx):
        """
        :param int sorted_seq_idx:
        :return corpus_seq_idx
        :rtype: int
        """
        return self._seq_order[sorted_seq_idx]

    def can_serialize_data(self, key):
        return True

    def serialize_data(self, key, data):
        if key in self.labels:
            return super().serialize_data(key, data)
        if isinstance(data, numpy.ndarray):
            data = data.tolist()
        return data
Related is the Sisyphus job to prepare HuggingFace datasets (https://github.com/rwth-i6/i6_core/pull/253). Doesn't this handle the caching? Ideally we should prepare our dataset wrapper here such that it works properly together with this download preparation job.
@dthulke Can you give some examples of what HF datasets you use?
Related is the Sisyphus job to prepare HuggingFace datasets (https://github.com/rwth-i6/i6_core/pull/253). Doesn't this handle the caching? Ideally we should prepare our dataset wrapper here such that it works properly together with this download preparation job.
Yes, this handles the caching of the initial dataset download, but not the caching of the processed version (via dataset.map). But we could add a separate job for this.
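A rough sketch of what such a job could look like (not part of i6_core; the job name and its attributes are made up for illustration):

from sisyphus import Job, Task


class PreprocessHuggingfaceDatasetJob(Job):
    """Applies a map function to a HF dataset and stores the result with save_to_disk()."""

    def __init__(self, dataset_opts, map_func):
        self.dataset_opts = dataset_opts
        self.map_func = map_func
        self.out_dataset = self.output_path("dataset", directory=True)

    def tasks(self):
        yield Task("run", mini_task=True)

    def run(self):
        import datasets
        ds = datasets.load_dataset(**self.dataset_opts)
        ds = self.map_func(ds)
        # the RETURNN dataset can then point its dataset_opts to this path (load_from_disk)
        ds.save_to_disk(self.out_dataset.get_path())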
@dthulke Can you give some examples of what HF datasets you use?
In RETURNN, I mainly use HF datasets for sequence classification (e.g. sentiment analysis) or sequence tagging tasks (e.g. named entity recognition). For example: https://huggingface.co/datasets/conll2003
In addition, I have a few datasets that I load with custom dataset loading scripts or the default JSON dataset implementation.
One example of additional preprocessing (beyond tokenisation) is to include document-level/cross-sentence context, or to convert NER labels (given as start/end positions) to BIO labels.
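For illustration, such a label conversion could be a simple map function like this (the field names "tokens", "entities" etc. are just assumptions):

def spans_to_bio(example):
    # convert entity spans (token-level start/end positions) into per-token BIO tags
    tags = ["O"] * len(example["tokens"])
    for entity in example["entities"]:
        start, end, label = entity["start"], entity["end"], entity["type"]
        tags[start] = "B-" + label
        for i in range(start + 1, end):
            tags[i] = "I-" + label
    example["bio_tags"] = tags
    return example

# dataset = dataset.map(spans_to_bio)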