Question about checkpoint strategy
I want to push my results to the Hugging Face Hub with a save frequency of 2000, like in distilabel 0.6.0:
freq = 2000
dataset_checkpoint = DatasetCheckpoint(
    path=Path.cwd() / "checkpoint_folder_evol_cn",
    save_frequency=freq,
    strategy="hf-hub",
    extra_kwargs={"repo_id": "xxx/xxx", "token": "---"},
)
data_result = pipeline.generate(
    dataset=dataset,
    checkpoint_strategy=dataset_checkpoint,
    batch_size=40,
    num_generations=3,
    config_name="task2",
)
However, in distilabel 1.0.3 the documentation for push_to_hub says that we can use a checkpoint strategy, but it seems I can't do this via push_to_hub.
How can I set up a checkpoint strategy? Thanks a lot! 👍 👍 👍
Hi @YueWu0301! Indeed, to do that you only need to define an intermediate step within your Pipeline so that the PushToHub step is called: it will wait for all the previous batches to be completed, then merge those into a datasets.Dataset and push it to the Hub.
Find an example below, and let me know if you need further assistance with this issue 👍🏻
from distilabel.llms.groq import GroqLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadHubDataset, PushToHub
from distilabel.steps.tasks import TextGeneration
if __name__ == "__main__":
    with Pipeline(name="push-to-hub-checkpoint") as pipeline:
        load_hub_dataset = LoadHubDataset(
            name="load_dataset",
            output_mappings={"prompt": "instruction"},
        )
        text_generation = TextGeneration(
            name="text_generation",
            llm=GroqLLM(
                model="llama3-70b-8192",
                api_key="...",  # type: ignore
            ),
            input_batch_size=10,
            output_mappings={"model_name": "generation_model"},
        )
        load_hub_dataset.connect(text_generation)

        push_to_hub = PushToHub(name="push_to_hub")
        text_generation.connect(push_to_hub)

        # More steps here
        ...

    pipeline.run(
        parameters={
            "load_dataset": {
                "repo_id": "distilabel-internal-testing/instruction-dataset-mini",
                "split": "test",
            },
            "text_generation": {
                "llm": {
                    "generation_kwargs": {
                        "max_new_tokens": 1024,
                        "temperature": 0.7,
                    },
                },
            },
            "push_to_hub": {
                "repo_id": "text-generation-groq-checkpoint",
                "split": "train",
                "private": False,
                "token": "...",
            },
            # ...
        }
    )
Also @YueWu0301, note that we've fully refactored distilabel from 0.Y.Z to 1.0.0, so the code you're using does not match the distilabel docs page you're referencing, which already covers v1.0.0. We encourage you to explore the documentation for the v1.0.0 changes in detail and adapt your script to it: we now follow a DAG approach with an arbitrary number of steps, instead of only a generator and a labeller as before.
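To illustrate the DAG approach, here's a minimal sketch (the step names and the two-branch layout are purely illustrative, reusing only the classes from the example above) where a single loader fans out into two branches, each ending in its own PushToHub step; you would still pass the per-step parameters via pipeline.run as shown earlier:
from distilabel.llms.groq import GroqLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadHubDataset, PushToHub
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="dag-fan-out-example") as pipeline:
    # A single generator step feeding the rest of the DAG.
    load_hub_dataset = LoadHubDataset(
        name="load_dataset",
        output_mappings={"prompt": "instruction"},
    )

    # Two independent branches hanging from the same loader, each pushing
    # its own dataset to the Hub once all of its batches are processed.
    for branch in ("a", "b"):
        text_generation = TextGeneration(
            name=f"text_generation_{branch}",
            llm=GroqLLM(
                model="llama3-70b-8192",
                api_key="...",  # type: ignore
            ),
        )
        push_to_hub = PushToHub(name=f"push_to_hub_{branch}")
        load_hub_dataset.connect(text_generation)
        text_generation.connect(push_to_hub)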
You can find the latest documentation at https://distilabel.argilla.io/latest, and then navigate to the Tutorial section to see all the interfaces we introduce and how to use them!
Hi @YueWu0301, I'll close this issue for the moment, since it is inactive and we couldn't reproduce it on our end! Let us know if you still experience this issue and we'll re-open it! Thanks 👍🏻