
[Auto Parallel] Add Pipeline Memory Estimator

Open AndSonder opened this pull request 1 year ago • 7 comments

PR Category

Auto Parallel

PR Types

Improvements

Description

The main job of the PipelineMemoryEstimator class is to estimate GPU memory usage from a program: each time a program is added, its memory footprint is estimated from the program itself. Unlike an estimator for a single program, PipelineMemoryEstimator also maintains a type_to_skip_gc_vars dictionary, which records the variables that later programs still use (and therefore must not be garbage-collected).

The class mainly serves the zb-vpp automatic scheduling work and is used to estimate the memory usage of each program.
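
Below is a conceptual sketch of the skip-GC bookkeeping described above. It is an illustration only, not the PR's implementation; the helper build_skip_gc_vars and the variable names are hypothetical.

def build_skip_gc_vars(type_to_used_vars, ordered_types):
    # For each program type, keep the variables that a later program still reads,
    # so their memory is not treated as freed when this program finishes.
    type_to_skip_gc_vars = {}
    for i, cur_type in enumerate(ordered_types):
        later_used = set()
        for later_type in ordered_types[i + 1:]:
            later_used |= type_to_used_vars[later_type]
        type_to_skip_gc_vars[cur_type] = type_to_used_vars[cur_type] & later_used
    return type_to_skip_gc_vars

# Hypothetical variable names, just to show the shape of the result.
used_vars = {
    "forward": {"x", "hidden_0", "loss"},
    "backward": {"hidden_0", "loss", "grad_w"},
    "optimizer": {"grad_w"},
}
print(build_skip_gc_vars(used_vars, ["forward", "backward", "optimizer"]))
# forward keeps {'hidden_0', 'loss'}, backward keeps {'grad_w'}, optimizer keeps nothing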

Test environment: 4x 1080Ti GPUs, Llama2, PaddleNLP develop branch, 1F1B scheduling mode. The test script is as follows:

set -x
unset CUDA_VISIBLE_DEVICES

task_name="llama_auto_static_dp2sharding2mp2pp2_vpp2"
# rm -rf output/$task_name/  # ckpt is saved in 'output/''
rm -rf "output/$task_name""_log"

# export PARALLEL_CROSS_ENTROPY=true
export FLAGS_call_stack_level=4
export PYTHONPATH=../../../:$PYTHONPATH
export GLOG_v=0
# export FLAGS_pir_apply_inplace_pass=0
export FLAGS_log_memory_stats=1

python -u -m paddle.distributed.launch \
    --gpus "0,1,2,3" \
    --log_dir "output/$task_name""_log" \
    run_pretrain_auto_static.py \
    --model_type "llama" \
    --model_name_or_path "facebook/llama-7b" \
    --tokenizer_name_or_path "facebook/llama-7b" \
    --input_dir "../data" \
    --output_dir "output/$task_name" \
    --split 949,50,1 \
    --max_seq_length 2048 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --use_flash_attention 0 \
    --use_fused_rms_norm 0 \
    --fp16 0 \
    --fp16_opt_level "O2"  \
    --scale_loss 1024 \
    --pipeline_parallel_degree  4 \
    --tensor_parallel_degree 1 \
    --pipeline_schedule_mode "1F1B" \
    --learning_rate 0.0001 \
    --min_learning_rate 0.00001 \
    --max_steps 10 \
    --save_steps 5000 \
    --weight_decay 0.01 \
    --warmup_ratio 0.01 \
    --max_grad_norm 1.0 \
    --logging_steps 1 \
    --dataloader_num_workers 1 \
    --eval_steps 1000 \
    --report_to "visualdl" \
    --disable_tqdm true \
    --continue_training 0 \
    --recompute 1 \
    --recompute_granularity full \
    --do_train \
    --do_eval \
    --device "gpu" \
    --data_impl "mmap" \
    --enable_auto_parallel 1 \
    --sharding_parallel_degree 1 \
    --sharding "stage1" \
    --num_hidden_layers 4 \

To reproduce the test results, the 1F1B scheduling code needs to be hacked by overriding _apply_single_impl in the 1F1B pass class:

    def _apply_single_impl(self, main_program, startup_program, context):
        """
        The shared process is implemented in this function; a new subclass only needs
        to implement the two interfaces above, 'create_job_list' and 'partial_programs'.
        """
        # Split the main program into typed sub-programs (forward / backward / optimizer).
        job_types, sub_programs = self._partial_programs(main_program)

        enable_pir_in_executor = paddle.framework.get_flags(
            "FLAGS_enable_pir_in_executor"
        )['FLAGS_enable_pir_in_executor']
        if enable_pir_in_executor:
            shadow_var_between_sub_programs(sub_programs)

        for i in range(len(job_types)):
            logger.debug(
                f"sub_program type: {job_types[i]}, sub_program:\n{sub_programs[i]}"
            )

        jobs = self._create_job_list()
        for job in jobs:
            print(job.type())

        dist_context = self.get_attr("dist_context")
        if dist_context is None:
            raise ValueError("dist_context is None.")

        type_to_program = dict(zip(job_types, sub_programs))

        print(type_to_program)

        # Estimate memory with the estimator added in this PR (assumed to be
        # imported in this module alongside the other pass utilities).
        mem_tool = PipelineMemoryEstimator()
        program_types = ["forward", "backward", "optimizer"]
        # Record the variables that later programs still need and must not be GC'ed.
        mem_tool.set_program_skip_gc_vars(type_to_program, program_types)

        for job_type in ["forward", "backward"]:
            increase_mem, max_mem = mem_tool.estimate_memory(
                type_to_program[job_type], job_type, dist_context
            )
            print(
                f"Type: {job_type}, increase_mem: {increase_mem}, max_mem: {max_mem}, "
                f"param size: {mem_tool.get_program_param_size(job_type)}"
            )

        print(f"skip_gc_vars: {mem_tool.type_to_skip_gc_vars}")
        print(
            "base memory allocated: "
            f"{paddle.device.cuda.memory_allocated(self.get_attr('pp_stage'))}"
        )
        # Stop after printing the estimates so the training job does not continue.
        exit(0)

The estimated results and the model's actual runtime results are shown below (unit: MB). "Memory change after running" is the change in memory after a program finishes, measured against the memory at the moment the program starts. "Max during running" is the peak memory usage while the program runs, measured against the same baseline.
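
For reference, here is a minimal sketch of how these two quantities can be measured at runtime with Paddle's CUDA memory APIs. The helper below is hypothetical and not part of this PR.

import paddle

def measure_program_memory(run_one_program, device_id=0):
    # "Memory change after running" = allocated(after) - allocated(before).
    # "Max during running"          = peak allocated - allocated(before).
    mb = 1024 * 1024
    before = paddle.device.cuda.memory_allocated(device_id)
    run_one_program()  # e.g. run one forward or backward sub-program
    after = paddle.device.cuda.memory_allocated(device_id)
    # max_memory_allocated reports the peak since process start, so this is an
    # upper bound unless the peak actually occurs inside this run.
    peak = paddle.device.cuda.max_memory_allocated(device_id)
    return (after - before) / mb, (peak - before) / mb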

pp4, gradient accumulation 8, recompute enabled, batch size 1, num_hidden_layers 4

[image: estimated vs. actual memory per sub-program]

pp4, gradient accumulation 8, recompute disabled, batch size 1, num_hidden_layers 4

[image: estimated vs. actual memory per sub-program]

pp2, mp2, gradient accumulation 8, recompute disabled, batch size 2, num_hidden_layers 4

[image: estimated vs. actual memory per sub-program]

Related issue:

  • https://github.com/PaddlePaddle/Paddle/issues/62666

AndSonder · Apr 11 '24

Your PR has been submitted. Thanks for your contribution! Please wait for the CI results first. See the Paddle CI Manual for details.

paddle-bot[bot] · Apr 11 '24

Sorry to inform you that e2d78ee's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] · Apr 19 '24

Sorry to inform you that 3f5dcdf's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] · May 08 '24

Sorry to inform you that 6c3e2b1's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] · May 16 '24

Sorry to inform you that f31d13e's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] · May 24 '24

Sorry to inform you that c36e00d's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] · Jun 01 '24

Sorry to inform you that f798e7c's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] · Jun 21 '24

As a follow-up, multi-stream analysis could be added to further refine the memory estimation.

heavyrain-lzy · Jul 19 '24

As a follow-up, multi-stream analysis could be added to further refine the memory estimation.

Sounds good ~

AndSonder · Jul 19 '24