
Question about checkpoint strategy

YueWu0301 opened this issue on Apr 28, 2024 · 2 comments

I want to push my results to the Hugging Face Hub every 2000 rows, as I could in distilabel 0.6.0:

from pathlib import Path

from distilabel.dataset import DatasetCheckpoint  # assumed import path in distilabel 0.6.0

freq = 2000
dataset_checkpoint = DatasetCheckpoint(
    path=Path.cwd() / "checkpoint_folder_evol_cn",
    save_frequency=freq,
    strategy="hf-hub",
    extra_kwargs={"repo_id": "xxx/xxx", "token": "---"},
)

data_result = pipeline.generate(
    dataset=dataset, checkpoint_strategy=dataset_checkpoint,
    batch_size=40, num_generations=3, config_name="task2",
)

However, in distilabel 1.0.3, the docs for push_to_hub say that a checkpoint strategy is supported [image], but it seems I can't do this via push_to_hub [image]. How can I set up a checkpoint strategy? Thanks a lot! 👍 👍 👍

YueWu0301 commented Apr 28, 2024

Hi here @YueWu0301! Indeed, to do that you only need to define an intermediate step within your Pipeline so that the PushToHub step is called; it will wait for all the previous batches to be completed, then merge those into a datasets.Dataset and push it to the Hub.

Find an example below, and let me know if you need further assistance on this issue 👍🏻

from distilabel.llms.groq import GroqLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadHubDataset, PushToHub
from distilabel.steps.tasks import TextGeneration

if __name__ == "__main__":
    with Pipeline(name="push-to-hub-checkpoint") as pipeline:
        load_hub_dataset = LoadHubDataset(
            name="load_dataset",
            output_mappings={"prompt": "instruction"},
        )

        text_generation = TextGeneration(
            name="text_generation",
            llm=GroqLLM(
                model="llama3-70b-8192",
                api_key="...",  # type: ignore
            ),
            input_batch_size=10,
            output_mappings={"model_name": "generation_model"},
        )
        load_hub_dataset.connect(text_generation)

        # PushToHub waits for all upstream batches to complete, then merges
        # them into a single datasets.Dataset and pushes it to the Hub
        push_to_hub = PushToHub(name="push_to_hub")
        text_generation.connect(push_to_hub)

        # More steps here
        ...

    pipeline.run(
        parameters={
            "load_dataset": {
                "repo_id": "distilabel-internal-testing/instruction-dataset-mini",
                "split": "test",
            },
            "text_generation": {
                "llm": {
                    "generation_kwargs": {
                        "max_new_tokens": 1024,
                        "temperature": 0.7,
                    },
                },
            },
            "push_to_hub": {
                "repo_id": "text-generation-groq-checkpoint",
                "split": "train",
                "private": False,
                "token": "...",
            },
            # more runtime parameters here, if any
        }
    )
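
As a side note, since the refactor pipeline.run returns a Distiset, so you can also push the whole result once the run finishes. A minimal sketch (the repo id and token below are placeholders, and the exact push_to_hub arguments may vary across versions):

    # Sketch: push the merged result after the run, instead of using a PushToHub step
    distiset = pipeline.run(parameters=...)  # same parameters as above
    distiset.push_to_hub(
        "xxx/xxx",  # placeholder repo id
        private=False,
        token="...",  # placeholder token
    )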

alvarobartt commented Apr 28, 2024

Also, @YueWu0301, note that we've fully refactored distilabel from 0.Y.Z to 1.0.0, so the code you're running does not match the docs you're referencing: the page you pointed out already covers distilabel v1.0.0. We encourage you to explore the documentation in detail and adapt your script to the v1.0.0 changes; we now follow a DAG approach with an arbitrary number of steps, instead of only a generator and a labeller as before.
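
For instance, a minimal sketch of the DAG approach (step names and the second task are illustrative), where a single generator step fans out into two independent tasks using the same connect API as above:

from distilabel.llms.groq import GroqLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadHubDataset
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="dag-fan-out") as pipeline:
    loader = LoadHubDataset(name="load_dataset", output_mappings={"prompt": "instruction"})

    # Two independent branches hanging from the same generator step
    task_a = TextGeneration(name="task_a", llm=GroqLLM(model="llama3-70b-8192", api_key="..."))
    task_b = TextGeneration(name="task_b", llm=GroqLLM(model="llama3-70b-8192", api_key="..."))

    loader.connect(task_a)
    loader.connect(task_b)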

You can find the latest documentation at https://distilabel.argilla.io/latest; from there, navigate to the Tutorial section to see all the interfaces we introduced and how to use them!

alvarobartt commented Apr 28, 2024

Hi here @YueWu0301! I'll close this issue for the moment, since it has been inactive and we couldn't reproduce it on our end. Let us know if you still experience this issue and we'll re-open it! Thanks 👍🏻

alvarobartt commented May 9, 2024