
Model reloads after each completed partition

Open conceptofmind opened this issue 5 months ago • 9 comments

With the embedding UDF below, the model is reloaded and reinitialized every time a partition (parquet file) is completed and written:

import daft
import torch
import numpy as np
from typing import Optional
from daft import col

BATCH_SIZE = 1
NUM_GPUS = 1
return_dtype = daft.DataType.list(daft.DataType.float64())

@daft.udf(return_dtype=return_dtype, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE)
class STEmbeddingUDF:
    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        convert_to_tensor: bool = False,
        torch_dtype: torch.dtype = torch.float16,
        set_seq_len: bool = True,
        max_seq_length: Optional[int] = 2048,
    ):
        from sentence_transformers import SentenceTransformer
        
        self.model = SentenceTransformer(
            model_name,
            device=device,
            model_kwargs={"torch_dtype": torch_dtype}
        )

        if set_seq_len:
            self.model.max_seq_length = max_seq_length

        self.convert_to_tensor = convert_to_tensor
        self.device = device

    def __call__(self, text_col: daft.Series) -> list:
        # text_col is a daft.Series holding one batch of the "text" column
        embeddings = self.model.encode(
            text_col.to_pylist(),
            batch_size=BATCH_SIZE,
            convert_to_tensor=self.convert_to_tensor,
            device=self.device,
        )
        if not self.convert_to_tensor:
            # With BATCH_SIZE = 1, encode() returns a (1, dim) array; keep its single row as float64
            embeddings = embeddings[0].astype(np.float64)
        return [embeddings]

...

df = daft.read_parquet("hf://datasets/name")

processor = STEmbeddingUDF
data = col("text")
df = df.with_column("embeddings", processor(data))

df.write_parquet("embeddings")

The model reloads each time a partition/file is completed and written:

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:49<00:00,  7.09s/it]
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:47<00:00,  6.82s/it]
Loading checkpoint shards:  29%|████████████████████████████████▎                                                                                | 2/7 [00:13<00:32,  6.41s/it]
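For clarity, here is a minimal, library-free sketch of the difference between what the logs show (one model load per partition) and the expected behavior (one load reused across partitions). It does not use Daft and makes no claim about how Daft schedules UDFs internally; `LoadCounter`, `run_per_partition`, and `run_reused` are hypothetical names, and string length stands in for a real embedding.

```python
from typing import Iterable, List


class LoadCounter:
    """Hypothetical stand-in for an expensive model; counts how often it is constructed."""

    loads = 0

    def __init__(self) -> None:
        # Each construction simulates one "Loading checkpoint shards" pass.
        LoadCounter.loads += 1

    def __call__(self, batch: List[str]) -> List[int]:
        # Placeholder "embedding": the length of each string.
        return [len(s) for s in batch]


def run_per_partition(partitions: Iterable[List[str]]) -> List[List[int]]:
    # Observed behavior: a fresh instance (a fresh model load) for every partition.
    return [LoadCounter()(p) for p in partitions]


def run_reused(partitions: Iterable[List[str]]) -> List[List[int]]:
    # Expected behavior: one instance built once and reused for every partition.
    model = LoadCounter()
    return [model(p) for p in partitions]
```

With three partitions, `run_per_partition` constructs the model three times, while `run_reused` constructs it once and produces the same outputs.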

Any input would be greatly appreciated.

conceptofmind · Sep 21 '24 17:09