PaddleNLP icon indicating copy to clipboard operation
PaddleNLP copied to clipboard

[Improvement Request] 简化数据集加载逻辑并改进文档支持

Open natureLanguageQing opened this issue 5 months ago • 1 comments

Feature request

标题:

描述:

在使用 paddlenlp 进行训练和微调时,我发现 pretrain 和 sft 部分的代码实现存在一些问题,尤其是在数据集加载逻辑方面。当前实现包含了大量复杂的代码来处理数据集路径的检测和加载,这不仅影响了代码的可读性,也使得用户难以理解和维护。

当前实现问题:

    if data_args.dataset_name_or_path is None:
        raise ValueError(f"Please specific dataset name or path (got {data_args.dataset_name_or_path})")
    elif (
        os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json"))
        or os.path.exists(os.path.join(data_args.dataset_name_or_path, "dev.json"))
        or os.path.exists(os.path.join(data_args.dataset_name_or_path, "quant.json"))
    ):
        if training_args.do_train or quant_args.do_qat:
            train_ds = load_dataset(
                "json",
                data_files=os.path.join(data_args.dataset_name_or_path, "train.json"),
                lazy=data_args.lazy,
            )[0]
        else:
            train_ds = None
        if training_args.do_eval:
            dev_ds = load_dataset(
                "json",
                data_files=os.path.join(data_args.dataset_name_or_path, "dev.json"),
                lazy=data_args.lazy,
            )[0]
        else:
            dev_ds = None
        if quant_args.do_ptq or quant_args.do_gptq:
            if os.path.exists(os.path.join(data_args.dataset_name_or_path, "quant.json")):
                ptq_ds = load_dataset(
                    "json",
                    data_files=os.path.join(data_args.dataset_name_or_path, "quant.json"),
                    lazy=data_args.lazy,
                )[0]
            elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")):
                ptq_ds = load_dataset(
                    "json",
                    data_files=os.path.join(data_args.dataset_name_or_path, "train.json"),
                    lazy=data_args.lazy,
                )[0]
                logger.info(
                    f"Not found quant.json in {data_args.dataset_name_or_path}. Set train dataset as PTQ calibration dataset."
                )
            else:
                raise ValueError(
                    f"Quant strategy requires quant.json or train.json in {data_args.dataset_name_or_path}"
                )
        else:
            ptq_ds = None
    elif (
        os.path.exists(os.path.join(data_args.dataset_name_or_path, "train"))
        or os.path.exists(os.path.join(data_args.dataset_name_or_path, "dev"))
        or os.path.exists(os.path.join(data_args.dataset_name_or_path, "quant"))
    ):
        import glob

        if training_args.do_train or quant_args.do_qat:
            train_ds = load_dataset(
                "json",
                data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "train", "*.json")),
                lazy=data_args.lazy,
            )[0]
        else:
            train_ds = None
        if training_args.do_eval:
            dev_ds = load_dataset(
                "json",
                data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "dev", "*.json")),
                lazy=data_args.lazy,
            )[0]
        else:
            dev_ds = None
        if quant_args.do_ptq or quant_args.do_gptq:
            if os.path.exists(os.path.join(data_args.dataset_name_or_path, "quant")):
                ptq_ds = load_dataset(
                    "json",
                    data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "quant", "*.json")),
                    lazy=data_args.lazy,
                )[0]
            elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train")):
                ptq_ds = load_dataset(
                    "json",
                    data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "train", "*.json")),
                    lazy=data_args.lazy,
                )[0]
                logger.info(
                    f"Not found quant.json in {data_args.dataset_name_or_path}. Set train dataset as PTQ calibration dataset."
                )
            else:
                raise ValueError(f"Quant strategy requires quant or train folder in {data_args.dataset_name_or_path}")
        else:
            ptq_ds = None
    else:
        if training_args.do_train or quant_args.do_qat:
            train_ds = load_dataset(data_args.dataset_name_or_path, splits=["train"])[0]
        else:
            train_ds = None
        if training_args.do_eval:
            dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["dev"])[0]
        else:
            dev_ds = None
        if quant_args.do_ptq or quant_args.do_gptq:
            ptq_ds = load_dataset(data_args.dataset_name_or_path, splits=["train"])[0]
            logger.info("Set train dataset as PTQ calibration dataset.")
        else:
            ptq_ds = None

数据集加载逻辑复杂:

当前代码处理数据集路径和加载的逻辑非常繁琐。这种复杂性不仅使代码难以理解,而且增加了维护的难度。 文档支持不足:

当前文档中未包括对数据集加载逻辑的详细解读,使得用户很难理解如何正确配置和使用数据集。 建议的改进:

简化数据集加载逻辑:

目标: 减少代码的复杂性,通过更简洁的逻辑处理数据集加载,提升代码的可读性和维护性。 实现建议: 统一数据集格式和路径规范,减少路径检查和数据集加载的复杂度。 改进文档支持:

目标: 提供清晰的文档解释数据集加载逻辑和要求。 实现建议: 在文档中详细说明数据集文件的格式、路径要求,以及如何正确配置数据集。 添加使用示例和说明,帮助用户理解如何设置数据集并解决常见问题。 期望的改进:

简化数据集加载: 通过精简代码逻辑,提升代码可读性和维护性。 文档改进: 提供详细的文档说明,帮助用户理解数据集要求和配置。 感谢团队对项目的持续投入和改进。希望这些建议能对 paddlenlp 的发展有所帮助。

Motivation

当前代码处理数据集路径和加载的逻辑非常繁琐。这种复杂性不仅使代码难以理解,而且增加了维护的难度。 文档支持不足:

当前文档中未包括对数据集加载逻辑的详细解读,使得用户很难理解如何正确配置和使用数据集。

Your contribution

目标: 提供清晰的文档解释数据集加载逻辑和要求。 实现建议: 在文档中详细说明数据集文件的格式、路径要求,以及如何正确配置数据集。 添加使用示例和说明,帮助用户理解如何设置数据集并解决常见问题。 期望的改进:

简化数据集加载: 通过精简代码逻辑,提升代码可读性和维护性。 文档改进: 提供详细的文档说明,帮助用户理解数据集要求和配置。

natureLanguageQing avatar Aug 26 '24 07:08 natureLanguageQing