datasets
Issue when wanting to split in memory a cached dataset
Describe the bug
In the 'train_test_split' method of the Dataset class (defined in datasets/arrow_dataset.py), if 'self.cache_files' is not empty and the input parameters 'train_indices_cache_file_name' and 'test_indices_cache_file_name' are None, they are overwritten with non-None values in order to check whether the result can be served from cached data. If no cached data is available, the method carries on, but those two parameters are no longer None, which conflicts with the 'keep_in_memory' parameter further down. Indeed, the method eventually calls 'select', and if 'keep_in_memory' is True while 'indices_cache_file_name' is not None, an exception is raised with the message "Please use either `keep_in_memory` or `indices_cache_file_name` but not both.". As a result, it is impossible to perform a train / test split of a cached dataset while requesting that the result not be cached, which is inconvenient when one is just running experiments with no intention of caching the result.
Aside from this being inconvenient, the code which led up to that situation seems simply wrong to me: an input variable should not be modified in a way that changes the user's intention just to perform a test, if that test can fail and respecting the user's intention is necessary to proceed in that case. To fix this, I suggest using other variables / other variable names to hold the value(s) needed to perform the test, so as not to change the originally input values needed by the rest of the method's code.
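To illustrate the pattern described above, here is a minimal, self-contained sketch. It does not use the datasets library: 'select', 'split_buggy', 'split_fixed', 'EXISTING_CACHE_FILES' and 'DEFAULT_INDICES_PATH' are hypothetical stand-ins, and only the guard and its message are taken from the traceback below. It is a toy model of the control flow, not the library's actual code.

```python
from typing import Optional, Set

# Hypothetical stand-ins: a fake "filesystem" and a fake default cache path.
EXISTING_CACHE_FILES: Set[str] = set()
DEFAULT_INDICES_PATH = "cache-train-indices.arrow"


def select(keep_in_memory: bool, indices_cache_file_name: Optional[str]) -> str:
    """Toy version of the guard shown in the traceback for `select`."""
    if keep_in_memory and indices_cache_file_name is not None:
        raise ValueError(
            "Please use either `keep_in_memory` or `indices_cache_file_name` but not both."
        )
    if indices_cache_file_name is None:
        return "in-memory"
    return f"cached at {indices_cache_file_name}"


def split_buggy(keep_in_memory: bool, indices_cache_file_name: Optional[str] = None) -> str:
    # The caller's argument is overwritten just to probe the cache.
    if indices_cache_file_name is None:
        indices_cache_file_name = DEFAULT_INDICES_PATH
    if indices_cache_file_name in EXISTING_CACHE_FILES:
        return "loaded from cache"
    # Cache miss: the overwritten value reaches select() and trips its guard.
    return select(keep_in_memory, indices_cache_file_name)


def split_fixed(keep_in_memory: bool, indices_cache_file_name: Optional[str] = None) -> str:
    # A separate local holds the probe path; the caller's argument stays untouched.
    probe_path = (
        indices_cache_file_name
        if indices_cache_file_name is not None
        else DEFAULT_INDICES_PATH
    )
    if probe_path in EXISTING_CACHE_FILES:
        return "loaded from cache"
    # Use the default path only when the caller did not ask for in-memory output.
    target = indices_cache_file_name if keep_in_memory else probe_path
    return select(keep_in_memory, target)
```

In this toy model, 'split_buggy(keep_in_memory=True)' raises the ValueError, whereas 'split_fixed(keep_in_memory=True)' reaches 'select' with the original None value and succeeds.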
Also, I don't see why an exception should be raised when the 'select' method is called with both 'keep_in_memory'=True and 'indices_cache_file_name'!=None: shouldn't 'keep_in_memory' prevail anyway, since it states that the user does not want caching, which makes the value of 'indices_cache_file_name' irrelevant? This is in fact what happens further in the code, in the '_select_with_indices_mapping' method: when 'keep_in_memory' is True, the value of 'indices_cache_file_name' does not matter, and the data is written to a stream buffer anyway. Hence I suggest removing the exception in those circumstances, notably in the 'select', '_select_with_indices_mapping', 'shuffle' and 'map' methods.
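The precedence rule suggested above could look roughly like the following sketch. The function name 'choose_indices_sink' and its return values are hypothetical, and treating a None file name as "write to a buffer" is an assumption of this sketch rather than something confirmed by the traceback.

```python
from typing import Optional, Tuple


def choose_indices_sink(
    keep_in_memory: bool, indices_cache_file_name: Optional[str]
) -> Tuple[str, Optional[str]]:
    """Decide where the indices mapping goes, letting keep_in_memory prevail."""
    # keep_in_memory wins: any file name is ignored instead of raising an error.
    # Assumption: no file name also means an in-memory buffer.
    if keep_in_memory or indices_cache_file_name is None:
        return ("buffer", None)
    # Only write to disk when caching is allowed and a file name is available.
    return ("file", indices_cache_file_name)
```

With this rule, passing both 'keep_in_memory'=True and a file name yields the in-memory buffer, so a caller such as 'train_test_split' would no longer fail even if it had filled in a default file name.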
Steps to reproduce the bug
import datasets


def generate_examples():
    for i in range(10):
        yield {"id": i}


dataset_ = datasets.Dataset.from_generator(
    generate_examples,
    keep_in_memory=False,
)

dataset_.train_test_split(
    test_size=3,
    shuffle=False,
    keep_in_memory=True,
    train_indices_cache_file_name=None,
    test_indices_cache_file_name=None,
)
Expected behavior
The result of the above code should be a DatasetDict instance.
Instead, we get the following exception stack:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[3], line 1
----> 1 dataset_.train_test_split(
2 test_size=3,
3 shuffle=False,
4 keep_in_memory=True,
5 train_indices_cache_file_name=None,
6 test_indices_cache_file_name=None,
7 )
File ~/Work/Developments/datasets/src/datasets/arrow_dataset.py:528, in transmit_format.<locals>.wrapper(*args, **kwargs)
521 self_format = {
522 "type": self._format_type,
523 "format_kwargs": self._format_kwargs,
524 "columns": self._format_columns,
525 "output_all_columns": self._output_all_columns,
526 }
527 # apply actual function
--> 528 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
529 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
530 # re-apply format to the output
File ~/Work/Developments/datasets/src/datasets/fingerprint.py:511, in fingerprint_transform.<locals>._fingerprint.<locals>.wrapper(*args, **kwargs)
507 validate_fingerprint(kwargs[fingerprint_name])
509 # Call actual function
--> 511 out = func(dataset, *args, **kwargs)
513 # Update fingerprint of in-place transforms + update in-place history of transforms
515 if inplace: # update after calling func so that the fingerprint doesn't change if the function fails
File ~/Work/Developments/datasets/src/datasets/arrow_dataset.py:4428, in Dataset.train_test_split(self, test_size, train_size, shuffle, stratify_by_column, seed, generator, keep_in_memory, load_from_cache_file, train_indices_cache_file_name, test_indices_cache_file_name, writer_batch_size, train_new_fingerprint, test_new_fingerprint)
4425 test_indices = permutation[:n_test]
4426 train_indices = permutation[n_test : (n_test + n_train)]
-> 4428 train_split = self.select(
4429 indices=train_indices,
4430 keep_in_memory=keep_in_memory,
4431 indices_cache_file_name=train_indices_cache_file_name,
4432 writer_batch_size=writer_batch_size,
4433 new_fingerprint=train_new_fingerprint,
4434 )
4435 test_split = self.select(
4436 indices=test_indices,
4437 keep_in_memory=keep_in_memory,
(...)
4440 new_fingerprint=test_new_fingerprint,
4441 )
4443 return DatasetDict({"train": train_split, "test": test_split})
File ~/Work/Developments/datasets/src/datasets/arrow_dataset.py:528, in transmit_format.<locals>.wrapper(*args, **kwargs)
521 self_format = {
522 "type": self._format_type,
523 "format_kwargs": self._format_kwargs,
524 "columns": self._format_columns,
525 "output_all_columns": self._output_all_columns,
526 }
527 # apply actual function
--> 528 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
529 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
530 # re-apply format to the output
File ~/Work/Developments/datasets/src/datasets/fingerprint.py:511, in fingerprint_transform.<locals>._fingerprint.<locals>.wrapper(*args, **kwargs)
507 validate_fingerprint(kwargs[fingerprint_name])
509 # Call actual function
--> 511 out = func(dataset, *args, **kwargs)
513 # Update fingerprint of in-place transforms + update in-place history of transforms
515 if inplace: # update after calling func so that the fingerprint doesn't change if the function fails
File ~/Work/Developments/datasets/src/datasets/arrow_dataset.py:3679, in Dataset.select(self, indices, keep_in_memory, indices_cache_file_name, writer_batch_size, new_fingerprint)
3645 """Create a new dataset with rows selected following the list/array of indices.
3646
3647 Args:
(...)
3676 ```
3677 """
3678 if keep_in_memory and indices_cache_file_name is not None:
-> 3679 raise ValueError("Please use either `keep_in_memory` or `indices_cache_file_name` but not both.")
3681 if len(self.list_indexes()) > 0:
3682 raise DatasetTransformationNotAllowedError(
3683 "Using `.select` on a dataset with attached indexes is not allowed. You can first run `.drop_index() to remove your index and then re-add it."
3684 )
ValueError: Please use either `keep_in_memory` or `indices_cache_file_name` but not both.
Environment info
- datasets version: 2.11.1.dev0
- Platform: Linux-5.4.236-1-MANJARO-x86_64-with-glibc2.2.5
- Python version: 3.8.12
- Huggingface_hub version: 0.13.3
- PyArrow version: 11.0.0
- Pandas version: 2.0.0
EDIT: Now with a pull request to fix this here