Faiss GPU index cannot be serialised when passed to trainer
Describe the bug
I am working on a retrieval project and have encountered two issues in the Hugging Face faiss integration:
- I am trying to pass in a dataset with a faiss index to the Huggingface trainer. The code works for a cpu faiss index, but doesn't for a gpu one, getting error:
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/transformers/trainer.py", line 1543, in train
return inner_training_loop(
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/transformers/trainer.py", line 1555, in _inner_training_loop
train_dataloader = self.get_train_dataloader()
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/transformers/trainer.py", line 831, in get_train_dataloader
train_dataset = self._remove_unused_columns(train_dataset, description="training")
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/transformers/trainer.py", line 725, in _remove_unused_columns
return dataset.remove_columns(ignored_columns)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 592, in wrapper
out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 557, in wrapper
out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/datasets/fingerprint.py", line 481, in wrapper
out = func(dataset, *args, **kwargs)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2146, in remove_columns
dataset = copy.deepcopy(self)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/copy.py", line 271, in _reconstruct
state = deepcopy(state, memo)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/copy.py", line 231, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/copy.py", line 231, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/copy.py", line 271, in _reconstruct
state = deepcopy(state, memo)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/copy.py", line 231, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/copy.py", line 161, in deepcopy
rv = reductor(4)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/faiss/__init__.py", line 556, in index_getstate
return {"this": serialize_index(self).tobytes()}
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/faiss/__init__.py", line 1607, in serialize_index
write_index(index, writer)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/faiss/swigfaiss.py", line 9843, in write_index
return _swigfaiss.write_index(*args)
RuntimeError: Error in void faiss::write_index(const faiss::Index*, faiss::IOWriter*) at /project/faiss/faiss/impl/index_write.cpp:590: don't know how to serialize this type of index
The index was created with the add_faiss_index method
train_dataset.add_faiss_index(
column='embeddings',
index_name='embeddings',
string_factory=faiss_index_string,
train_size=config.faiss_train_size,
device=0, # Use -1 for CPU, or specify GPU device ID
faiss_verbose=True
)
- Although faiss is written to be compatible with the GPU for searching (https://github.com/facebookresearch/faiss/wiki/Faiss-on-the-GPU), I am getting an error when trying to use the Hugging Face code to do the search on GPU. This seems to be caused by this line https://github.com/huggingface/datasets/blob/f9975f636542df7f95c27065ea93147440d690b7/src/datasets/search.py#L376 which produces the error:
total_scores, total_examples = self.dataset.get_nearest_examples_batch('embeddings', embeddings, k=self.k)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/datasets/search.py", line 773, in get_nearest_examples_batch
total_scores, total_indices = self.search_batch(index_name, queries, k, **kwargs)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/datasets/search.py", line 727, in search_batch
return self._indexes[index_name].search_batch(queries, k, **kwargs)
File "/users/rubman/.conda/envs/protein_npt_env/lib/python3.10/site-packages/datasets/search.py", line 376, in search_batch
if not queries.flags.c_contiguous:
AttributeError: 'Tensor' object has no attribute 'flags'
Steps to reproduce the bug
train_dataset.add_faiss_index(
column='embeddings',
index_name='embeddings',
string_factory=faiss_index_string,
train_size=config.faiss_train_size,
device=0, # Use -1 for CPU, or specify GPU device ID
faiss_verbose=True
)
Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator,
tokenizer=tokenizer
)
train_dataset.get_nearest_examples_batch('embeddings', embeddings, k=self.k)
Expected behavior
I would expect the faiss database code to be gpu compatible
Environment info
huggingface Version: 2.16.1
Hi ! make sure your query embeddings are numpy arrays, not torch tensors ;)
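A minimal sketch of that conversion, in case it helps. The helper name to_faiss_queries is hypothetical (it is not part of datasets or faiss), and it detects torch tensors by duck typing so that torch does not need to be imported. It assumes the queries should end up as a C-contiguous float32 array, which is what faiss works with.

```python
import numpy as np


def to_faiss_queries(queries):
    """Return `queries` as a C-contiguous float32 NumPy array.

    Hypothetical helper: accepts a NumPy array or anything tensor-like
    that exposes .detach(), .cpu() and .numpy() (e.g. a torch.Tensor).
    """
    # Tensor-like input: move it off the GPU and drop autograd tracking.
    if hasattr(queries, "detach") and hasattr(queries, "cpu"):
        queries = queries.detach().cpu().numpy()
    # Converts the dtype and copies only if the layout is not already C-contiguous.
    return np.ascontiguousarray(queries, dtype=np.float32)
```

The search call from the report would then look like train_dataset.get_nearest_examples_batch('embeddings', to_faiss_queries(embeddings), k=k), where k is the number of neighbours to retrieve.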
Hi Quentin, I'm not sure how that solves problem number 1. I am trying to pass a dataset with a faiss GPU index to the standard trainer for training, but I am getting this serialisation error. What is a workaround for this? I do not want to remove the faiss index, as I want to use it to create batches of retrieved samples from the dataset. Thanks in advance for your help!
Issue number one seems to be an issue with FAISS indexes not being compatible with copy.deepcopy.
Maybe you can try not removing the columns, e.g. by passing remove_unused_columns=False
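To illustrate the mechanism, here is a toy reproduction that uses only the standard library. Everything in it is a stand-in: UnserialisableIndex, ToyDataset and prepare_for_training are hypothetical names that only mimic the behaviour visible in the traceback above (remove_columns deep-copies the dataset, and the copy fails when it reaches the index). It is not the actual datasets or transformers code.

```python
import copy


class UnserialisableIndex:
    """Hypothetical stand-in for a faiss GPU index that cannot be serialised."""

    def __getstate__(self):
        # copy.deepcopy asks the object for its state, just like pickling does.
        raise RuntimeError("don't know how to serialize this type of index")


class ToyDataset:
    """Hypothetical stand-in for a dataset that stores its indexes in a dict."""

    def __init__(self, columns, indexes):
        self.columns = list(columns)
        self.indexes = dict(indexes)

    def remove_columns(self, names):
        # As in the traceback: the dataset is deep-copied before columns are dropped.
        new = copy.deepcopy(self)
        new.columns = [c for c in new.columns if c not in names]
        return new


def prepare_for_training(dataset, unused, remove_unused_columns=True):
    # Toy model of the trainer's column-pruning step.
    if not remove_unused_columns:
        return dataset  # no remove_columns call, so no deepcopy
    return dataset.remove_columns(unused)


ds = ToyDataset(["input_ids", "embeddings"], {"embeddings": UnserialisableIndex()})

try:
    prepare_for_training(ds, ["embeddings"])
    copy_failed = False
except RuntimeError:
    copy_failed = True

# Skipping the pruning step returns the same dataset with its index intact.
kept = prepare_for_training(ds, ["embeddings"], remove_unused_columns=False)
```

With the real Trainer, the corresponding setting is remove_unused_columns=False on TrainingArguments. Note that every column then reaches the data collator, so the collator has to drop or handle the ones the model does not accept.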