swift
swift copied to clipboard
custom.py注册自定义数据集报错
···python #custom.py
from typing import Optional, Tuple
from datasets import Dataset as HfDataset from modelscope import MsDataset
from swift.llm import get_dataset, register_dataset, get_dataset_from_repo from swift.utils import get_logger
# Module-level logger obtained from swift.utils.get_logger(); it is not
# referenced anywhere else in the visible part of this script.
logger = get_logger()
class CustomDatasetName:
    """Namespace of dataset-name constants used by this script."""

    # Name under which the STS-B dataset is registered and later requested.
    stsb_en = 'stsb-en'
def _preprocess_stsb(dataset: HfDataset) -> HfDataset:
    """Turn sentence-pair rows into ``query``/``response`` text columns.

    Args:
        dataset: Iterable of rows; each row is indexed with the keys
            ``'text1'``, ``'text2'`` and ``'label'``. ``label`` must accept
            the ``.1f`` format spec (i.e. be numeric).

    Returns:
        A new ``HfDataset`` built with ``HfDataset.from_dict`` that has
        exactly two columns: ``query`` (the filled-in prompt) and
        ``response`` (the label rendered with one decimal place).
    """
    # The trailing space after "Similarity score:" is intentional: the
    # response text is what follows the prompt.
    prompt = """Task: Based on the given two sentences, provide a similarity score between 0.0 and 5.0.
Sentence 1: {text1}
Sentence 2: {text2}
Similarity score: """
    query = []
    response = []
    # Single pass so the input is only iterated once.
    for d in dataset:
        query.append(prompt.format(text1=d['text1'], text2=d['text2']))
        response.append(f"{d['label']:.1f}")
    return HfDataset.from_dict({'query': query, 'response': response})
# Runs at import time: registers the name 'stsb-en' with swift, passing the
# repo id 'huangjintao/stsb', a None placeholder, the preprocessing function
# and get_dataset_from_repo as positional arguments.
# NOTE(review): the meaning of each positional slot is defined by
# swift.llm.register_dataset, which is not visible here — confirm the argument
# order against the signature of the installed swift version.
register_dataset(CustomDatasetName.stsb_en, 'huangjintao/stsb', None, _preprocess_stsb, get_dataset_from_repo)
if __name__ == '__main__':
    # Smoke test, executed only when this file is run as a script: request the
    # dataset by its name and print the two splits that get_dataset returns.
    train_dataset, val_dataset = get_dataset(
        [CustomDatasetName.stsb_en], check_dataset_strategy='warning')
    print(f'train_dataset: {train_dataset}')
    print(f'val_dataset: {val_dataset}')
我没有做任何修改,直接运行官方的 custom.py 示例,结果报错找不到 'stsb-en' 数据集。请问应该怎样添加自定义数据集或自定义的 prompt 模板?