Daft
Model loads after each completed partition
Given the embedding UDF below, the model reloads and reinitializes after each completed partition (Parquet file) is written:
import daft
import torch
import numpy as np
from typing import Optional

BATCH_SIZE = 1
NUM_GPUS = 1
return_dtype = daft.DataType.list(daft.DataType.float64())


@daft.udf(return_dtype=return_dtype, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE)
class STEmbeddingUDF:
    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        convert_to_tensor: bool = False,
        torch_dtype: torch.dtype = torch.float16,
        set_seq_len: bool = True,
        max_seq_length: Optional[int] = 2048,
    ):
        from sentence_transformers import SentenceTransformer

        self.model = SentenceTransformer(
            model_name,
            device=device,
            model_kwargs={"torch_dtype": torch_dtype},
        )
        if set_seq_len:
            self.model.max_seq_length = max_seq_length
        self.convert_to_tensor = convert_to_tensor
        self.device = device

    def __call__(self, text_col: daft.DataFrame) -> daft.DataFrame:
        embeddings = self.model.encode(
            text_col.to_pylist(),
            batch_size=BATCH_SIZE,
            convert_to_tensor=self.convert_to_tensor,
            device=self.device,
        )
        if self.convert_to_tensor is not True:
            embeddings = embeddings[0].astype(np.float64)
        return [embeddings]
...
df = daft.read_parquet("hf://datasets/name")
processor = STEmbeddingUDF
data = daft.col("text")
df = df.with_column("embeddings", processor(data))
df.write_parquet("embeddings")
Model reloads after completing and writing one partition/file:
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:49<00:00, 7.09s/it]
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:47<00:00, 6.82s/it]
Loading checkpoint shards: 29%|████████████████████████████████▎ | 2/7 [00:13<00:32, 6.41s/it]
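One way to check whether the UDF object itself is being rebuilt for every partition (rather than the model reloading for some other reason) is to count `__init__` calls. The following is a minimal pure-Python sketch of that idea, not Daft code: `InitCounter`, `run_per_partition`, and `run_shared` are hypothetical stand-ins for the UDF class and for two possible ways a runner could schedule it. In the real UDF, the counter and print would sit next to the `SentenceTransformer(...)` call.

```python
import os


class InitCounter:
    """Hypothetical stand-in for the UDF class; counts __init__ calls in this process."""

    init_calls = 0

    def __init__(self):
        type(self).init_calls += 1
        # Printing the process ID helps tell "new object" apart from "new worker process".
        print(f"init #{type(self).init_calls} in pid {os.getpid()}")

    def __call__(self, texts):
        # Placeholder for model.encode(...): returns the length of each string.
        return [len(t) for t in texts]


def run_per_partition(partitions):
    """Hypothetical runner that rebuilds the UDF object for every partition."""
    return [InitCounter()(p) for p in partitions]


def run_shared(partitions):
    """Hypothetical runner that builds the UDF object once and reuses it."""
    udf = InitCounter()
    return [udf(p) for p in partitions]


partitions = [["a", "bb"], ["ccc"], ["dddd", "e"]]

InitCounter.init_calls = 0
per_partition_out = run_per_partition(partitions)
per_partition_inits = InitCounter.init_calls

InitCounter.init_calls = 0
shared_out = run_shared(partitions)
shared_inits = InitCounter.init_calls
```

If the same kind of counter in the actual UDF increases once per written file, the class is being re-instantiated per partition, which would match the repeated "Loading checkpoint shards" lines above.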
Any input would be greatly appreciated.