
HuggingFace datasets wrapper

Open · albertz opened this issue 2 years ago · 4 comments

We should be able to use HuggingFace datasets directly in RETURNN.

I guess the most canonical way would be to write a RETURNN Dataset for this. Maybe derived from CachedDataset2.

A separate, more direct PyTorch dataset wrapper might also make sense. Or actually, I think it's not needed, as HuggingFace already supports this directly?
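For reference, the direct route could be as simple as this sketch (dataset and column names are just placeholders; an HF `Dataset` implements `__len__`/`__getitem__`, so it plugs into a `DataLoader` directly):

```python
import datasets
from torch.utils.data import DataLoader

# placeholder dataset; with_format("torch") makes numeric columns come out as torch tensors
ds = datasets.load_dataset("conll2003", split="train")
ds = ds.with_format("torch", columns=["ner_tags"])

# variable-length sequences need a real collate_fn; identity here just for the sketch
loader = DataLoader(ds, batch_size=4, collate_fn=lambda items: items)
```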

albertz · Feb 07 '23 12:02

> I guess the most canonical way would be to write a RETURNN Dataset for this. Maybe derived from CachedDataset2.

I already implemented this as a custom dataset some time ago (see implementation below).

Things to discuss are maybe how to handle tokenisation and other preprocessing steps. In the current implementation, there is either the option to define a map function (see https://huggingface.co/docs/datasets/about_map_batch) in the config file, or to provide a preprocessed dataset stored with save_to_disk.
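For illustration, a config could look roughly like this (a sketch; the dataset name, column names and the trivial tokenisation are placeholders, and it assumes HuggingfaceDataset is importable from the config):

```python
def preprocess(hf_dataset):
  # map_func gets the loaded datasets.Dataset and returns the mapped one
  def to_ids(batch):
    batch["data"] = [[ord(c) for c in text] for text in batch["text"]]  # placeholder tokenisation
    return batch
  return hf_dataset.map(to_ids, batched=True)

train = {
  "class": "HuggingfaceDataset",
  "dataset_opts": {"path": "some_dataset", "split": "train"},  # kwargs for datasets.load_dataset
  "map_func": "preprocess",  # name resolved via config.typed_value
  "seq_tag_key": "id",
}
```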

num_outputs and data_dtype are currently extracted from the dataset's features attribute (which may not always be available).

Another point to discuss is how to handle caching for large datasets (currently the default caching mechanism of HF datasets is used, which does not fit very well with our setups).
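As a sketch of what could be tuned there (cache_dir and disable_caching are standard HF datasets APIs; the paths are placeholders):

```python
import datasets

# keep the download/processing cache inside the experiment dir instead of ~/.cache
ds = datasets.load_dataset("some_dataset", cache_dir="/path/to/work/hf_cache")

# or disable the on-disk cache for map() results entirely, so processed
# versions only live in memory or in explicit save_to_disk() outputs
datasets.disable_caching()
```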

Code

```python
import numpy
from returnn.datasets.basic import DatasetSeq
from returnn.datasets.cached2 import CachedDataset2
from returnn.util.basic import OptionalNotImplementedError


class HuggingfaceDataset(CachedDataset2):
  """
  Wraps a HuggingFace ``datasets.Dataset`` as a RETURNN dataset.
  """

  @classmethod
  def kwargs_update_from_config(cls, config, kwargs):
    super().kwargs_update_from_config(config, kwargs)
    if 'map_func' in kwargs:
      if isinstance(kwargs['map_func'], str):
        kwargs['map_func'] = config.typed_value(kwargs['map_func'])

  def __init__(self, dataset_opts, map_func=None, map_func_args=None, data_key='data', seq_tag_key='id', features=None,
               **kwargs):
    """
    :param dict|str dataset_opts: kwargs for datasets.load_dataset, or a path for datasets.load_from_disk
    :param function|str|None map_func: function applied to the loaded dataset, or its name in the config
    :param dict|None map_func_args: if given, map_func is treated as a factory and called with these args first
    :param str data_key: feature which is used to determine the seq length
    :param str|None seq_tag_key: feature which provides the seq tag
    :param list[str]|None features: feature keys to expose; inferred from the dataset if None
    """
    super(HuggingfaceDataset, self).__init__(**kwargs)

    self._seq_order = None

    self.dataset_opts = dataset_opts

    if isinstance(map_func, str):
      from returnn.config import get_global_config
      config = get_global_config(raise_exception=False)
      assert config, "map_func given as str but no global config available"
      map_func = config.typed_value(map_func)

    if map_func_args is not None:
      # map_func is treated as a factory here and returns the actual map function
      map_func = map_func(**map_func_args)
    self.map_func = map_func
    self.dataset = None
    self.data_key = data_key
    self.seq_tag_key = seq_tag_key

    self.feature_keys = features

    self.data_dtype = {}

  def initialize(self):
    """
    Loads the dataset, applies the map function if given, and derives
    num_outputs, labels and dtypes from the dataset features.
    """
    import datasets
    if isinstance(self.dataset_opts, dict):
      self.dataset = datasets.load_dataset(**self.dataset_opts)
    else:
      self.dataset = datasets.load_from_disk(self.dataset_opts)
      assert isinstance(self.dataset, datasets.Dataset)
    if self.map_func is not None:
      self.dataset = self.map_func(self.dataset)
    if self.feature_keys is None:
      self.feature_keys = list(self.dataset.features.keys())
      if self.seq_tag_key is not None:
        assert self.seq_tag_key in self.feature_keys, "dataset does not have the seq tag feature %r" % self.seq_tag_key
        self.feature_keys.remove(self.seq_tag_key)

    self.dataset.set_format('numpy')

    if self.seq_tag_key is not None:
      assert self.seq_tag_key in self.dataset.column_names

    self.labels = {}
    self.num_outputs = {}
    for key in self.feature_keys:
      feature = self.dataset.features[key]
      dtype = None
      num_classes = None
      spatial_dims = 0
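      # unwrap (possibly nested) Sequence features to get the inner value type;
      # each Sequence level corresponds to one spatial (length) axis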
      while type(feature) is datasets.features.Sequence:
        spatial_dims += 1
        if feature.length != -1:
          num_classes = feature.length
        feature = feature.feature
      if type(feature) is datasets.features.ClassLabel:
        self.labels[key] = feature.names
        dtype = feature.dtype
        num_classes = feature.num_classes
      elif type(feature) is datasets.features.Value:
        dtype = feature.dtype
      elif isinstance(feature, (datasets.features.Array2D, datasets.features.Array3D, datasets.features.Array4D)):
        dtype = feature.dtype
        num_classes = feature.shape[-1]
        spatial_dims += len(feature.shape)
      else:
        assert False, f"Unsupported feature type {type(feature)}"

      len_shape = spatial_dims
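      # num_outputs[key] = [dim, ndim], following the RETURNN convention (ndim without batch dim)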
      self.num_outputs[key] = [num_classes, len_shape]

      self.data_dtype[key] = dtype

    super().initialize()

  def get_data_dim(self, key):
    if key in self.num_outputs:
      return self.num_outputs[key][0]
    return super().get_data_dim(key)

  def get_data_dtype(self, key):
    return self.data_dtype[key]

  def _get_seq_len(self, seq_idx):
    """
    :param int seq_idx: corpus seq idx
    :rtype: int
    """
    return len(self.dataset[seq_idx][self.data_key])

  @property
  def num_seqs(self):
    assert self._seq_order is not None, "num_seqs is only known after calling init_seq_order()"
    return len(self._seq_order)

  def get_tag(self, sorted_seq_idx):
    return self.dataset[int(self.get_corpus_seq_idx(sorted_seq_idx))][self.seq_tag_key]

  def get_all_tags(self):
    return list(self.dataset[self.seq_tag_key])

  def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
    """
    :param int|None epoch:
    :param list[str]|None seq_list: List of sequence tags, to set a predefined order.
    :param list[int]|None seq_order: List of corpus sequence indices, to set a predefined order.
    :rtype: bool
    :returns whether the order changed (True is always safe to return)
    """
    super(HuggingfaceDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)

    if seq_order is not None:
      self._seq_order = seq_order
      # TODO can we return False?
      return True

    if seq_list is not None:
      # build a dict once for O(1) tag lookup instead of repeated list.index()
      tag_to_idx = {}
      for i, tag in enumerate(self.get_all_tags()):
        tag_to_idx.setdefault(tag, i)
      self._seq_order = [tag_to_idx[tag] for tag in seq_list]
      # TODO can we return False?
      return True

    try:
      self._seq_order = self.get_seq_order_for_epoch(
        epoch=epoch, num_seqs=self.dataset.num_rows, get_seq_len=self._get_seq_len)
    except OptionalNotImplementedError:
      # only seq orderings which do not need the seq lengths are supported here
      assert self.seq_ordering in ["default", "reverse", "random"]
      self._seq_order = self.get_seq_order_for_epoch(
        epoch=epoch, num_seqs=self.dataset.num_rows, get_seq_len=None)

    return True

  def _collect_single_seq(self, seq_idx):
    """
    :param int seq_idx: sorted seq idx
    :rtype: DatasetSeq
    """
    corpus_seq_idx = self.get_corpus_seq_idx(seq_idx)

    def ensure_numpy(x):
      if not isinstance(x, numpy.ndarray):
        return numpy.array(x)
      return x

    dataset_item = self.dataset[int(corpus_seq_idx)]
    features = {f: ensure_numpy(dataset_item[f]) for f in self.feature_keys}
    return DatasetSeq(
      seq_idx,
      features=features,
      targets=None,
      seq_tag=dataset_item[self.seq_tag_key]
    )

  def get_current_seq_order(self):
    """
    :rtype: list[int]
    """
    assert self._seq_order is not None
    return self._seq_order

  def get_corpus_seq_idx(self, sorted_seq_idx):
    """
    :param int sorted_seq_idx:
    :return corpus_seq_idx
    :rtype: int
    """
    return self._seq_order[sorted_seq_idx]

  def can_serialize_data(self, key):
    return True

  def serialize_data(self, key, data):
    if key in self.labels:
      return super().serialize_data(key, data)
    if isinstance(data, numpy.ndarray):
      data = data.tolist()
    return data
```

dthulke · Feb 08 '23 10:02

Related is the Sisyphus job to prepare HuggingFace datasets (https://github.com/rwth-i6/i6_core/pull/253). Doesn't this handle the caching? Ideally we should prepare our dataset wrapper here such that it works properly together with this download preparation job.

albertz · Feb 08 '23 11:02

@dthulke Can you give some examples of which HF datasets you use?

albertz · Feb 08 '23 11:02

> Related is the Sisyphus job to prepare HuggingFace datasets (https://github.com/rwth-i6/i6_core/pull/253). Doesn't this handle the caching? Ideally we should prepare our dataset wrapper here such that it works properly together with this download preparation job.

Yes, this handles the caching of the initial dataset download, but not the caching of the processed version (via dataset.map). But we could add a separate job for this.
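A sketch of what such a job could do internally (paths, dataset and column names are placeholders; the Sisyphus wrapping itself is omitted):

```python
import datasets

def add_data(batch):
  # placeholder preprocessing; a real job would tokenise etc. here
  batch["data"] = batch["tokens"]
  return batch

# one-time preprocessing, e.g. inside a dedicated Sisyphus job
ds = datasets.load_dataset("some_dataset", split="train")
ds = ds.map(add_data, batched=True)
ds.save_to_disk("/path/to/work/processed_dataset")

# the wrapper above then takes the path as dataset_opts, which hits its
# load_from_disk branch and bypasses the HF cache:
# dataset_opts = "/path/to/work/processed_dataset"
```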

> @dthulke Can you give some examples of which HF datasets you use?

In RETURNN, I mainly use HF datasets for sequence classification (e.g. sentiment analysis) or sequence tagging tasks (e.g. named entity recognition). For example: https://huggingface.co/datasets/conll2003
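For illustration, the features metadata that the wrapper's initialize() introspects looks like this for such datasets (a sketch):

```python
import datasets

ds = datasets.load_dataset("conll2003", split="train")
# ner_tags is a Sequence(ClassLabel(...)); the wrapper derives labels,
# num_outputs and dtype from exactly this metadata
print(ds.features["ner_tags"])
```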

In addition, I have a few datasets that I load with custom dataset loading scripts or the default json dataset implementation.

One example of additional preprocessing (beyond tokenisation) is including document-level/cross-sentence context, or converting NER labels (given as start/end positions) to BIO labels (see the sketch below).
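For the latter, a minimal sketch of span-to-BIO conversion (the (start, end, label) token-offset format is an assumption for illustration):

```python
def spans_to_bio(num_tokens, entities):
  """
  Convert entity spans to BIO tags.
  Assumed input format (hypothetical): list of (start, end, label)
  with token-level, end-exclusive offsets.
  """
  tags = ["O"] * num_tokens
  for start, end, label in entities:
    tags[start] = "B-" + label
    for i in range(start + 1, end):
      tags[i] = "I-" + label
  return tags

# e.g. spans_to_bio(5, [(1, 3, "PER")]) -> ["O", "B-PER", "I-PER", "O", "O"]
```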

dthulke · Feb 08 '23 11:02