
【Hackathon 8th No.32】 Adam-mini fine-tuning algorithm reproduction

Open megemini opened this issue 8 months ago • 16 comments

Before submitting

  • [x] Lint code. If there are lint issues, please format the code first.
# Install and register `pre-commit` in the project folder
pip install pre-commit && pre-commit install

# Process previous code files separately
pre-commit run --file XXXX.py
  • [ ] Add test cases to the tests folder. If there are codecov issues, please add test cases first.

PR types

New features

PR changes

Others

Description

Related task: https://github.com/PaddlePaddle/community/blob/master/hackathon/hackathon_8th/%E3%80%90Hackathon_8th%E3%80%91%E4%B8%AA%E4%BA%BA%E6%8C%91%E6%88%98%E8%B5%9B%E2%80%94%E5%A5%97%E4%BB%B6%E5%BC%80%E5%8F%91%E4%BB%BB%E5%8A%A1%E5%90%88%E9%9B%86.md#no32-adam-mini-%E7%B2%BE%E8%B0%83%E7%AE%97%E6%B3%95%E5%A4%8D%E7%8E%B0

Implementation approach:

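  • _add_moments_pows adds the accumulators in 4 separate cases.
  • _append_optimize_op performs the update in 3 cases. There is one case fewer because the computation for the last two cases should be the same; they differ only in how broadcasting is done.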
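To illustrate where the memory saving comes from, here is a minimal pure-Python sketch, assuming (as stated in the memory analysis later in this thread) that for most 2-D weights the blockwise moment2 keeps one value per row, i.e. shape [param.shape[0], 1]. The helper names are hypothetical, and the four accumulator cases in _add_moments_pows are not reproduced here.

```python
def rowwise_second_moment(grad, v_prev, beta2):
    """Hypothetical sketch: one EMA of mean(g^2) per row of a 2-D gradient.

    grad: list of rows (each a list of floats); v_prev: one float per row.
    Returns the updated per-row second moment, i.e. a [rows, 1]-shaped state.
    """
    new_v = []
    for row, v in zip(grad, v_prev):
        mean_sq = sum(g * g for g in row) / len(row)
        new_v.append(beta2 * v + (1.0 - beta2) * mean_sq)
    return new_v


def moment2_numel(shape, blockwise):
    """Number of moment2 elements for a [rows, cols] weight."""
    rows, cols = shape
    # Full AdamW: same shape as param. Blockwise: one value per row.
    return rows if blockwise else rows * cols


grad = [[1.0, 3.0], [2.0, 2.0]]
v = rowwise_second_moment(grad, [0.0, 0.0], beta2=0.999)
print(v)
print(moment2_numel((4096, 4096), blockwise=False), moment2_numel((4096, 4096), blockwise=True))
```

Under this assumption, a 4096 x 4096 weight needs 4,096 moment2 values instead of 16,777,216; moment1 keeps the full parameter shape either way.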

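Current development status:

  • After modifying the config file used by tests/llm/test_pretrain.py to use adamw_mini, the test passes.
  • So far the tests only cover cases 3 and 4 of _add_moments_pows and case 3 of _append_optimize_op.
  • The current logic does not involve distributed computation.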

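The llama model currently cannot be used in the tests. This is most likely not a problem introduced by this PR; here is the comment from the existing code: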

@parameterized_class(
    ["model_dir"],
    [
        # ["llama"], @skip("Skip and wait to fix.")
        # ["qwen"], @skip("Skip and wait to fix.")
        ["qwen2"],
        ["gpt"],
    ],
)

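Next steps:

  • Add test coverage for cases 1 and 2 of _add_moments_pows and cases 1 and 2 of _append_optimize_op
  • Measure GPU memory usage with the llama3 model
  • Align the resulting loss with the reference

The main thing I haven't figured out yet is how to test llama3. Is there a suitable config file? Would llm/config/llama/sft_argument.json work?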

@DrownFish19

Thanks! :)

megemini, Apr 15 '25 08:04

Thanks for your contribution!

paddle-bot[bot], Apr 15 '25 08:04

Codecov Report

Patch coverage is 86.82635% with 22 lines in your changes missing coverage. Please review.
Project coverage is 46.91%. Comparing base (e7420b1) to head (f0f8864).
Report is 74 commits behind head on develop.

Files with missing lines         Patch %   Lines
paddlenlp/utils/optimizer.py     86.82%    22 Missing
Additional details and impacted files
@@             Coverage Diff             @@
##           develop   #10413      +/-   ##
===========================================
+ Coverage    46.83%   46.91%   +0.08%     
===========================================
  Files          800      800              
  Lines       132933   133062     +129     
===========================================
+ Hits         62254    62422     +168     
+ Misses       70679    70640      -39     


codecov[bot], Apr 15 '25 08:04

@megemini You can base the config on llm/config/llama/sft_argument.json. Its distributed settings may need to be adjusted for your environment; the key thing to verify is how much GPU memory is saved.

DrownFish19, Apr 18 '25 05:04

Update 20250424

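1. Algorithm implementation

First, isn't there a bug in the original AdamWMini algorithm? The original code here:

[Screenshot: original AdamWMini update code]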

Shouldn't it be:

p += (mom1 / denom) * (-(lr / (1.0 - beta1_pow)))
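A minimal pure-Python sketch of one such step for a single block, assuming mom1 is the first moment after this step's EMA update, the block shares one scalar second moment v, and denom = sqrt(v) / sqrt(1 - beta2_pow) + eps (the usual Adam bias correction). Weight decay is omitted, and the function name and defaults are illustrative, not taken from the PR code.

```python
import math


def adam_mini_block_step(p, grad, m, v, step, lr=3e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """Sketch of one Adam-mini style update for a block sharing one second moment.

    p, grad, m: lists of floats for the block; v: float; step: 1-based step index.
    Returns (new_p, new_m, new_v). Weight decay is intentionally left out.
    """
    # First moment: per-element EMA ("mom1" in the formula above).
    mom1 = [beta1 * mi + (1.0 - beta1) * gi for mi, gi in zip(m, grad)]
    # Second moment: a single scalar per block, EMA of mean(g^2).
    mean_sq = sum(g * g for g in grad) / len(grad)
    new_v = beta2 * v + (1.0 - beta2) * mean_sq
    beta1_pow, beta2_pow = beta1 ** step, beta2 ** step
    denom = math.sqrt(new_v) / math.sqrt(1.0 - beta2_pow) + eps
    # p += (mom1 / denom) * (-(lr / (1.0 - beta1_pow)))
    new_p = [pi + (mi / denom) * (-(lr / (1.0 - beta1_pow))) for pi, mi in zip(p, mom1)]
    return new_p, mom1, new_v


p, m, v = adam_mini_block_step([0.0, 0.0], [0.5, 0.5], [0.0, 0.0], 0.0, step=1)
print(p)
```

With a constant gradient, the first step moves each element by about lr, as in standard Adam, which is a quick sanity check for the sign and the bias-correction terms.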

Then, as mentioned before, after splitting parameters into blocks:

  • _add_moments_pows adds the accumulators in 4 separate cases.
  • _append_optimize_op performs the update in 3 cases.

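2. GPU memory analysis

Test environment: AI Studio, 32 GB V100.

Test command:

  • python -u run_finetune.py ./config/llama/lora_argument.json

The sft_argument.json setup does not fit in GPU memory, so LoRA is used instead.

The AdamWMini implemented here mainly reduces the GPU memory used by moment2. In the original AdamW, moment2 has the same shape as param, and memory usage is as follows: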


[2025-04-24 13:24:56,538] [    INFO] - loss: 2.19820595, learning_rate: 0.0003, global_step: 6, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4724, interval_samples_per_second: 2.1169, interval_steps_per_second: 2.1169, ppl: 9.008836689307477, progress_or_epoch: 0.012

With the optimizer states removed entirely, i.e. without the GPU memory for moment1 and moment2:


[2025-04-24 13:30:41,708] [    INFO] - loss: 2.19820595, learning_rate: 0.0003, global_step: 6, current_memory_allocated: 15.238787412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.429, interval_samples_per_second: 2.3307, interval_steps_per_second: 2.3307, ppl: 9.008836689307477, progress_or_epoch: 0.012

可得,AdamWmoment1/moment2 占用显存为: 15.395037412643433 - 15.238787412643433 = 0.15625

当使用 AdamWMini 算法后,由于大部分需要优化的参数,其 moment2shape 都是 [param.shape[0], 1],占用显存情况如下:


[2025-04-24 14:40:40,786] [    INFO] - loss: 2.19820595, learning_rate: 0.0003, global_step: 6, current_memory_allocated: 15.321604490280151, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.6338, interval_samples_per_second: 1.5779, interval_steps_per_second: 1.5779, ppl: 9.008836689307477, progress_or_epoch: 0.012

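So the memory saved, as a fraction of AdamW's moment1/moment2 memory, is: (15.395037412643433 - 15.321604490280151) / 0.15625 = 0.469970703

Since this is only LoRA fine-tuning, the absolute savings are small. If resources allow, it would be worth checking the savings under pretraining, though AI Studio probably can't handle that...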
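The same ratio can be reproduced from the logged current_memory_allocated values (assumed to be in GiB). Because moment1 and moment2 are the same size under AdamW, a saved fraction close to 0.5 means moment2 has been almost entirely eliminated while moment1 is kept:

```python
adamw_alloc = 15.395037412643433       # AdamW: full-size moment1 + moment2
no_moments_alloc = 15.238787412643433  # moment1/moment2 removed
mini_alloc = 15.321604490280151        # blockwise AdamWMini from this PR

moments_gib = adamw_alloc - no_moments_alloc
saved_fraction = (adamw_alloc - mini_alloc) / moments_gib
print(round(moments_gib, 6), round(saved_fraction, 6))
```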

For comparison, here is the GPU memory usage of the original AdamWMini, which does not split parameters into blocks:


[2025-04-24 14:18:27,326] [    INFO] - loss: 2.19820595, learning_rate: 0.0003, global_step: 6, current_memory_allocated: 15.31701922416687, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.6674, interval_samples_per_second: 1.4984, interval_steps_per_second: 1.4984, ppl: 9.008836689307477, progress_or_epoch: 0.012

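Memory saved, as a fraction: (15.395037412643433 - 15.31701922416687) / 0.15625 = 0.499316406

The original AdamWMini uses even less memory because it does not split parameters into blocks: every moment2 has shape [1]. However, this does not match the algorithm as implemented by the Adam-mini authors, so it is not used as a reference here.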

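3. Problems with the algorithm

While comparing optimization accuracy against AdamW, I found a problem: updating param directly in Python, instead of through the C++ kernel, does not seem to work correctly.

To inspect this, I printed the parameters inside the algorithm by inserting the following code: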


        if param.name == 'lo_ra_linear_223.w_2':    
            print(">>>>>", master_weight is None)
            print('>'*20, 
                'm1', moment1.sum(), 'm2', moment2.sum(), 
                'b1', beta1_pow_acc.sum(), 'b2', beta2_pow_acc.sum(), 
                'master_weight', master_weight.sum(), 'param', param_and_grad[0].sum(), 'grad', param_and_grad[1].sum())

With AdamW:


[2025-04-24 15:19:23,985] [    INFO] -   Total num train samples = 500
[2025-04-24 15:19:23,991] [   DEBUG] -   Number of trainable parameters = 20,971,520 (per device)
W0424 15:19:24.498792 301809 multiply_fwd_func.cc:76] got different data type, run type promotion automatically, this may cause data type been changed.
W0424 15:19:24.506394 301809 gpu_resources.cc:306] WARNING: device: 0. The installed Paddle is compiled with CUDNN 8.9, but CUDNN version in your machine is 8.9, which may cause serious incompatible bug. Please recompile or reinstall Paddle with compatible CUDNN version.
Found inf or nan, current scale is: 32768.0, decrease to: 32768.0*0.5
[2025-04-24 15:19:25,264] [ WARNING] - optimizer not run, scale_before: 32768.0, scale_after: 16384.0
[2025-04-24 15:19:25,268] [    INFO] - loss: 4.01541185, learning_rate: 0.0, global_step: 1, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 1.2765, interval_samples_per_second: 0.7834, interval_steps_per_second: 0.7834, ppl: 55.44612618781211, progress_or_epoch: 0.002
Found inf or nan, current scale is: 16384.0, decrease to: 16384.0*0.5
[2025-04-24 15:19:25,657] [ WARNING] - optimizer not run, scale_before: 16384.0, scale_after: 8192.0
[2025-04-24 15:19:25,661] [    INFO] - loss: 2.49645019, learning_rate: 0.0, global_step: 2, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.3921, interval_samples_per_second: 2.5504, interval_steps_per_second: 2.5504, ppl: 12.139325087796644, progress_or_epoch: 0.004
Found inf or nan, current scale is: 8192.0, decrease to: 8192.0*0.5
[2025-04-24 15:19:26,041] [ WARNING] - optimizer not run, scale_before: 8192.0, scale_after: 4096.0
[2025-04-24 15:19:26,044] [    INFO] - loss: 1.44202387, learning_rate: 0.0, global_step: 3, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.3833, interval_samples_per_second: 2.609, interval_steps_per_second: 2.609, ppl: 4.229246606564234, progress_or_epoch: 0.006
Found inf or nan, current scale is: 4096.0, decrease to: 4096.0*0.5
[2025-04-24 15:19:26,426] [ WARNING] - optimizer not run, scale_before: 4096.0, scale_after: 2048.0
[2025-04-24 15:19:26,429] [    INFO] - loss: 5.85809898, learning_rate: 0.0, global_step: 4, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.3853, interval_samples_per_second: 2.5952, interval_steps_per_second: 2.5952, ppl: 350.05804374322315, progress_or_epoch: 0.008
Found inf or nan, current scale is: 2048.0, decrease to: 2048.0*0.5
[2025-04-24 15:19:26,831] [ WARNING] - optimizer not run, scale_before: 2048.0, scale_after: 1024.0
[2025-04-24 15:19:26,836] [    INFO] - loss: 1.94955564, learning_rate: 0.0, global_step: 5, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4064, interval_samples_per_second: 2.4606, interval_steps_per_second: 2.4606, ppl: 7.025565006800808, progress_or_epoch: 0.01
>>>>> False
>>>>>>>>>>>>>>>>>>>> m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.89999998) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.99900001) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       0.) grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       0.03613281)
[2025-04-24 15:19:27,365] [    INFO] - loss: 2.19820595, learning_rate: 0.0003, global_step: 6, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.5298, interval_samples_per_second: 1.8877, interval_steps_per_second: 1.8877, ppl: 9.008836689307477, progress_or_epoch: 0.012
Found inf or nan, current scale is: 1024.0, decrease to: 1024.0*0.5
[2025-04-24 15:19:27,869] [ WARNING] - optimizer not run, scale_before: 1024.0, scale_after: 512.0
[2025-04-24 15:19:27,875] [    INFO] - loss: 3.99188352, learning_rate: 0.0003, global_step: 7, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.5089, interval_samples_per_second: 1.9651, interval_steps_per_second: 1.9651, ppl: 54.15679877261727, progress_or_epoch: 0.014
Found inf or nan, current scale is: 512.0, decrease to: 512.0*0.5
[2025-04-24 15:19:28,377] [ WARNING] - optimizer not run, scale_before: 512.0, scale_after: 256.0
[2025-04-24 15:19:28,381] [    INFO] - loss: 2.8999517, learning_rate: 0.0003, global_step: 8, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.5069, interval_samples_per_second: 1.9729, interval_steps_per_second: 1.9729, ppl: 18.173267579420518, progress_or_epoch: 0.016
Found inf or nan, current scale is: 256.0, decrease to: 256.0*0.5
[2025-04-24 15:19:28,875] [ WARNING] - optimizer not run, scale_before: 256.0, scale_after: 128.0
[2025-04-24 15:19:28,879] [    INFO] - loss: 4.1729598, learning_rate: 0.0003, global_step: 9, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.498, interval_samples_per_second: 2.0082, interval_steps_per_second: 2.0082, ppl: 64.90728064956862, progress_or_epoch: 0.018
>>>>> False
>>>>>>>>>>>>>>>>>>>> m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.00361453) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.00000435) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.80999994) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.99800104) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       0.) grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       0.01861572)
[2025-04-24 15:19:29,404] [    INFO] - loss: 1.64721501, learning_rate: 0.0002994, global_step: 10, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.5251, interval_samples_per_second: 1.9044, interval_steps_per_second: 1.9044, ppl: 5.192498614806668, progress_or_epoch: 0.02
Found inf or nan, current scale is: 128.0, decrease to: 128.0*0.5
[2025-04-24 15:19:29,912] [ WARNING] - optimizer not run, scale_before: 128.0, scale_after: 64.0
[2025-04-24 15:19:29,917] [    INFO] - loss: 3.22934556, learning_rate: 0.0002994, global_step: 11, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.5125, interval_samples_per_second: 1.9513, interval_steps_per_second: 1.9513, ppl: 25.263118364607866, progress_or_epoch: 0.022
>>>>> False
>>>>>>>>>>>>>>>>>>>> m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.00511430) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.00000622) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.72899991) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.99700308) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       -0.13641964) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       -0.13635254) grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       0.03591919)
[2025-04-24 15:19:30,438] [    INFO] - loss: 3.001441, learning_rate: 0.0002988, global_step: 12, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.521, interval_samples_per_second: 1.9193, interval_steps_per_second: 1.9193, ppl: 20.114501045532172, progress_or_epoch: 0.024
>>>>> False
>>>>>>>>>>>>>>>>>>>> m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.00819490) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.00000712) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.65609992) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.99600607) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       -0.22575189) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       -0.22570801) grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       -0.05255127)
[2025-04-24 15:19:30,962] [    INFO] - loss: 2.2637279, learning_rate: 0.0002982, global_step: 13, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.5247, interval_samples_per_second: 1.9058, interval_steps_per_second: 1.9058, ppl: 9.618880636929768, progress_or_epoch: 0.026

With AdamWMini, by contrast:


[2025-04-24 15:16:38,172] [    INFO] -   Total num train samples = 500
[2025-04-24 15:16:38,175] [   DEBUG] -   Number of trainable parameters = 20,971,520 (per device)
W0424 15:16:38.704833 296812 multiply_fwd_func.cc:76] got different data type, run type promotion automatically, this may cause data type been changed.
W0424 15:16:38.712059 296812 gpu_resources.cc:306] WARNING: device: 0. The installed Paddle is compiled with CUDNN 8.9, but CUDNN version in your machine is 8.9, which may cause serious incompatible bug. Please recompile or reinstall Paddle with compatible CUDNN version.
Found inf or nan, current scale is: 32768.0, decrease to: 32768.0*0.5
[2025-04-24 15:16:39,594] [ WARNING] - optimizer not run, scale_before: 32768.0, scale_after: 16384.0
[2025-04-24 15:16:39,599] [    INFO] - loss: 4.01541185, learning_rate: 0.0, global_step: 1, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 1.4227, interval_samples_per_second: 0.7029, interval_steps_per_second: 0.7029, ppl: 55.44612618781211, progress_or_epoch: 0.002
Found inf or nan, current scale is: 16384.0, decrease to: 16384.0*0.5
[2025-04-24 15:16:40,021] [ WARNING] - optimizer not run, scale_before: 16384.0, scale_after: 8192.0
[2025-04-24 15:16:40,024] [    INFO] - loss: 2.49645019, learning_rate: 0.0, global_step: 2, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.425, interval_samples_per_second: 2.3529, interval_steps_per_second: 2.3529, ppl: 12.139325087796644, progress_or_epoch: 0.004
Found inf or nan, current scale is: 8192.0, decrease to: 8192.0*0.5
[2025-04-24 15:16:40,433] [ WARNING] - optimizer not run, scale_before: 8192.0, scale_after: 4096.0
[2025-04-24 15:16:40,436] [    INFO] - loss: 1.44202387, learning_rate: 0.0, global_step: 3, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4121, interval_samples_per_second: 2.4269, interval_steps_per_second: 2.4269, ppl: 4.229246606564234, progress_or_epoch: 0.006
Found inf or nan, current scale is: 4096.0, decrease to: 4096.0*0.5
[2025-04-24 15:16:40,840] [ WARNING] - optimizer not run, scale_before: 4096.0, scale_after: 2048.0
[2025-04-24 15:16:40,843] [    INFO] - loss: 5.85809898, learning_rate: 0.0, global_step: 4, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4068, interval_samples_per_second: 2.4583, interval_steps_per_second: 2.4583, ppl: 350.05804374322315, progress_or_epoch: 0.008
Found inf or nan, current scale is: 2048.0, decrease to: 2048.0*0.5
[2025-04-24 15:16:41,249] [ WARNING] - optimizer not run, scale_before: 2048.0, scale_after: 1024.0
[2025-04-24 15:16:41,252] [    INFO] - loss: 1.94955564, learning_rate: 0.0, global_step: 5, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.409, interval_samples_per_second: 2.445, interval_steps_per_second: 2.445, ppl: 7.025565006800808, progress_or_epoch: 0.01
>>>>> False
>>>>>>>>>>>>>>>>>>>> m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.89999998) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.99900001) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       0.) grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       0.03613281)
[2025-04-24 15:16:41,970] [    INFO] - loss: 2.19820595, learning_rate: 0.0003, global_step: 6, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.7184, interval_samples_per_second: 1.3921, interval_steps_per_second: 1.3921, ppl: 9.008836689307477, progress_or_epoch: 0.012
Found inf or nan, current scale is: 1024.0, decrease to: 1024.0*0.5
[2025-04-24 15:16:42,371] [ WARNING] - optimizer not run, scale_before: 1024.0, scale_after: 512.0
[2025-04-24 15:16:42,374] [    INFO] - loss: 3.99188352, learning_rate: 0.0003, global_step: 7, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4037, interval_samples_per_second: 2.4771, interval_steps_per_second: 2.4771, ppl: 54.15679877261727, progress_or_epoch: 0.014
Found inf or nan, current scale is: 512.0, decrease to: 512.0*0.5
[2025-04-24 15:16:42,773] [ WARNING] - optimizer not run, scale_before: 512.0, scale_after: 256.0
[2025-04-24 15:16:42,776] [    INFO] - loss: 2.8999517, learning_rate: 0.0003, global_step: 8, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4018, interval_samples_per_second: 2.4887, interval_steps_per_second: 2.4887, ppl: 18.173267579420518, progress_or_epoch: 0.016
Found inf or nan, current scale is: 256.0, decrease to: 256.0*0.5
[2025-04-24 15:16:43,173] [ WARNING] - optimizer not run, scale_before: 256.0, scale_after: 128.0
[2025-04-24 15:16:43,176] [    INFO] - loss: 4.1729598, learning_rate: 0.0003, global_step: 9, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4001, interval_samples_per_second: 2.4994, interval_steps_per_second: 2.4994, ppl: 64.90728064956862, progress_or_epoch: 0.018
>>>>> False
>>>>>>>>>>>>>>>>>>>> m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.00362164) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.00000072) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.80999994) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.99800104) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       0.) grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       0.01861572)
[2025-04-24 15:16:43,868] [    INFO] - loss: 1.64721501, learning_rate: 0.0002994, global_step: 10, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.6926, interval_samples_per_second: 1.4438, interval_steps_per_second: 1.4438, ppl: 5.192498614806668, progress_or_epoch: 0.02
Found inf or nan, current scale is: 128.0, decrease to: 128.0*0.5
[2025-04-24 15:16:44,266] [ WARNING] - optimizer not run, scale_before: 128.0, scale_after: 64.0
[2025-04-24 15:16:44,269] [    INFO] - loss: 26.33701515, learning_rate: 0.0002994, global_step: 11, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4004, interval_samples_per_second: 2.4975, interval_steps_per_second: 2.4975, ppl: 274170263505.28873, progress_or_epoch: 0.022
Found inf or nan, current scale is: 64.0, decrease to: 64.0*0.5
[2025-04-24 15:16:44,670] [ WARNING] - optimizer not run, scale_before: 64.0, scale_after: 32.0
[2025-04-24 15:16:44,673] [    INFO] - loss: 25.63891411, learning_rate: 0.0002994, global_step: 12, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4043, interval_samples_per_second: 2.4735, interval_steps_per_second: 2.4735, ppl: 136407710588.60149, progress_or_epoch: 0.024
>>>>> False
>>>>>>>>>>>>>>>>>>>> m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.00512069) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       0.00000083) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.72899991) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.99700308) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       -1044.12109375) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       -1044.) grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True,
       0.)
[2025-04-24 15:16:45,375] [    INFO] - loss: 21.70650673, learning_rate: 0.0002988, global_step: 13, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.7015, interval_samples_per_second: 1.4254, interval_steps_per_second: 1.4254, ppl: 2673105467.692815, progress_or_epoch: 0.026

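As shown above, the parameter update clearly goes wrong: from step 11 the loss jumps to about 26, and master_weight for this parameter reaches -1044. The result differs from the C++ implementation.

I also tried in-place operators here, but still could not match AdamW. It looks like something may be wrong with updating tensors from Python. Could you please take a look?

  • Note: in this test, the C++ implementation inside the AdamW optimizer was disabled and replaced with a Python implementation, so both runs use the same amount of GPU memory.

@DrownFish19 Please help confirm, thanks!!!

Appendix:

  • lora_argument.json config file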

{
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B",
    "dataset_name_or_path": "./data",
    "output_dir": "./checkpoints/lora_ckpts",
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "per_device_eval_batch_size": 1,
    "eval_accumulation_steps":16,
    "num_train_epochs": 1,
    "learning_rate": 3e-04,
    "warmup_steps": 1,
    "logging_steps": 1,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "src_length": 1024,
    "max_length": 2048,
    "bf16": false,
    "fp16": true,
    "fp16_opt_level": "O2",
    "do_train": true,
    "do_eval": true,
    "disable_tqdm": true,
    "load_best_model_at_end": true,
    "eval_with_do_generation": false,
    "metric_for_best_model": "accuracy",
    "recompute": true,
    "save_total_limit": 1,
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "sharding": "stage1",
    "lora": true,
    "zero_padding": false,
    "use_flash_attention": false,
    "unified_checkpoint": true,
    "pissa": false,
    "use_mora": false,
    "optim": "adamw_mini"
  }

megemini, Apr 24 '25 07:04

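@DrownFish19 @luotao1 Are we still pursuing this algorithm? 🫠

megemini, May 27 '25 07:05

Is there an external Python implementation? Can this be aligned with that Python implementation?

DrownFish19, Jun 04 '25 04:06

> Is there an external Python implementation? Can this be aligned with that Python implementation?

I used https://github.com/zyushun/Adam-mini/blob/main/adam_mini/adam_mini.py as the reference. The external Python versions are all based on torch and differ from Paddle in some details; for example, torch stores variables in state, while Paddle uses accumulators.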

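Can they be aligned? In short: the accuracy cannot be aligned.

As explained above, the cause may be an issue in how Paddle itself updates tensors, and may have little to do with the algorithm itself.

For now, the GPU memory reduction works, but the accuracy issue needs to be investigated on the Paddle side.

megemini, Jun 04 '25 04:06

Is this update formula, p += (mom1 / denom) * (-(lr / (1.0 - beta1_pow))), different from the original algorithm?

DrownFish19, Jun 04 '25 05:06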

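>> Is there an external Python implementation? Can this be aligned with that Python implementation?
>
> I used https://github.com/zyushun/Adam-mini/blob/main/adam_mini/adam_mini.py as the reference. The external Python versions are all based on torch and differ from Paddle in some details; for example, torch stores variables in state, while Paddle uses accumulators.

It's just a different implementation; the computation should be similar. Can you verify the range of the precision error for a single update?

DrownFish19, Jun 04 '25 05:06

> Is this update formula, p += (mom1 / denom) * (-(lr / (1.0 - beta1_pow))), different from the original algorithm?

The original was wrong: the first operand should be mom1, not moment1.

megemini, Jun 04 '25 05:06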
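One way to measure the single-update error range, sketched in pure Python: run one step of two implementations on identical inputs and report the max absolute difference. Here the two "implementations" are the update formula quoted in this thread and the textbook bias-corrected form lr * m_hat / (sqrt(v_hat) + eps). They are algebraically identical, so the difference should sit at float64 rounding level. In a real check, one side would be the Paddle optimizer and the other the torch Adam_mini, with tensors converted to lists or arrays; that wiring is an assumption and is not shown.

```python
import math


def step_formula(p, mom1, v, lr, beta1_pow, beta2_pow, eps):
    # p += (mom1 / denom) * (-(lr / (1.0 - beta1_pow)))
    denom = math.sqrt(v) / math.sqrt(1.0 - beta2_pow) + eps
    return [pi + (mi / denom) * (-(lr / (1.0 - beta1_pow))) for pi, mi in zip(p, mom1)]


def step_textbook(p, mom1, v, lr, beta1_pow, beta2_pow, eps):
    # p -= lr * m_hat / (sqrt(v_hat) + eps), with bias-corrected moments
    v_hat = v / (1.0 - beta2_pow)
    return [pi - lr * (mi / (1.0 - beta1_pow)) / (math.sqrt(v_hat) + eps) for pi, mi in zip(p, mom1)]


def max_abs_error(a, b):
    return max(abs(x - y) for x, y in zip(a, b))


p = [0.1, -0.2, 0.3]
mom1 = [0.01, -0.02, 0.005]
args = dict(v=4e-4, lr=3e-4, beta1_pow=0.9 ** 3, beta2_pow=0.999 ** 3, eps=1e-8)
err = max_abs_error(step_formula(p, mom1, **args), step_textbook(p, mom1, **args))
print(err)
```

With fp16 parameters and fp32 master weights, as in the logs above, the acceptable tolerance would be set by fp32/fp16 precision rather than float64, so a threshold around fp32 epsilon times the update magnitude is a more realistic target.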
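To see why the operand matters, a minimal sketch assuming moment1 is the accumulator value from before this step and mom1 is the value after this step's EMA update (names as in the comment above; the surrounding PR code is not reproduced). On the first step the accumulator is still zero, so using moment1 would produce no update at all, and on later steps it would lag one gradient behind:

```python
import math


def update(first_moment, v_new, lr=3e-4, beta1_pow=0.9, beta2_pow=0.999, eps=1e-8):
    # Parameter delta: (first_moment / denom) * (-(lr / (1.0 - beta1_pow)))
    denom = math.sqrt(v_new) / math.sqrt(1.0 - beta2_pow) + eps
    return (first_moment / denom) * (-(lr / (1.0 - beta1_pow)))


beta1, beta2 = 0.9, 0.999
grad = 0.5
moment1 = 0.0  # accumulator before step 1 (zero-initialized)
mom1 = beta1 * moment1 + (1.0 - beta1) * grad  # first moment after step 1
v_new = beta2 * 0.0 + (1.0 - beta2) * grad * grad

print(update(moment1, v_new))  # stale operand
print(update(mom1, v_new))     # fresh operand
```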

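>>> Is there an external Python implementation? Can this be aligned with that Python implementation?
>>
>> I used https://github.com/zyushun/Adam-mini/blob/main/adam_mini/adam_mini.py as the reference. The external Python versions are all based on torch and differ from Paddle in some details; for example, torch stores variables in state, while Paddle uses accumulators.
>
> It's just a different implementation; the computation should be similar. Can you verify the range of the precision error for a single update?

The current problem is that Paddle's C++ and Python implementations don't match: tensor updates in an optimizer implemented in Python may be broken (for example, implementing AdamW in Python and comparing it against the original AdamW optimizer). So aligning with torch isn't possible yet.

megemini, Jun 04 '25 05:06

Update 20250606

  • Now aligned with torch's Adam_mini algorithm.

The test code is below: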


import paddle
import torch
import numpy as np
import matplotlib.pyplot as plt
from adam_mini import Adam_mini
from optimizer import AdamWMini

# Set random seeds for reproducibility
SEED = 1
np.random.seed(SEED)
paddle.seed(SEED)
torch.manual_seed(SEED)
STEPS = 200

# Set default devices
paddle.set_device('gpu:0')
torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class SimpleTransformerPaddle(paddle.nn.Layer):
    def __init__(self, dim=2048, n_heads=32):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        
        # Embedding layer
        self.embd = paddle.nn.Embedding(100, dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        
        # Query/Key/Value projections
        self.wq = paddle.nn.Linear(dim, dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        self.wk = paddle.nn.Linear(dim, dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        self.wv = paddle.nn.Linear(dim, dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        
        # Attention projection
        self.wo = paddle.nn.Linear(dim, dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        
        # MLP layers
        self.mlp = paddle.nn.Sequential(
            paddle.nn.Linear(dim, 4*dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02)),
            paddle.nn.GELU(),
            paddle.nn.Linear(4*dim, dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        )
        
        # Output layer
        self.lm_head = paddle.nn.Linear(dim, 100, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        
        # Bias parameters
        self.bias = paddle.create_parameter([dim], dtype='float32', default_initializer=paddle.nn.initializer.Constant(value=0.0))

    def forward(self, input_ids):
        batch_size = input_ids.shape[0]
        seq_len = input_ids.shape[1]
        
        # Embedding
        hidden_states = self.embd(input_ids)  # [batch_size, seq_len, dim]

        # print('1='*10)
        # print(input_ids.shape, input_ids.sum())
        # print(hidden_states.shape, hidden_states.sum())
        
        # Query/Key/Value projections and reshape for multi-head attention
        query = self.wq(hidden_states)  # [batch_size, seq_len, dim]
        key = self.wk(hidden_states)    # [batch_size, seq_len, dim]
        value = self.wv(hidden_states)  # [batch_size, seq_len, dim]
        
        # print('2='*10)
        # print(query.shape, query.sum())
        # print(key.shape, key.sum())
        # print(value.shape, value.sum())

        # Reshape to [batch_size, seq_len, n_heads, head_dim]
        query = query.reshape([batch_size, seq_len, self.n_heads, self.head_dim])
        key = key.reshape([batch_size, seq_len, self.n_heads, self.head_dim])
        value = value.reshape([batch_size, seq_len, self.n_heads, self.head_dim])

        # Transpose to [batch_size, n_heads, seq_len, head_dim]
        query = query.transpose([0, 2, 1, 3])
        key = key.transpose([0, 2, 1, 3])
        value = value.transpose([0, 2, 1, 3])
        
        # Scaled dot-product attention
        scale = self.head_dim ** -0.5
        attn_weights = paddle.matmul(query * scale, key.transpose([0, 1, 3, 2]))
        attn_weights = paddle.nn.functional.softmax(attn_weights, axis=-1)
        
        # print('3='*10)
        # print(scale)
        # print(attn_weights.shape, attn_weights.sum())

        # Apply attention to values
        attn_output = paddle.matmul(attn_weights, value)  # [batch_size, n_heads, seq_len, head_dim]
        
        # Reshape back to [batch_size, seq_len, dim]
        attn_output = attn_output.transpose([0, 2, 1, 3])
        attn_output = attn_output.reshape([batch_size, seq_len, self.dim])
        
        # Attention output projection
        attn_output = self.wo(attn_output)
        
        # print('4='*10)
        # print(attn_output.shape, attn_output.sum())

        # Feed forward
        feed_forward = self.mlp(attn_output)
        
        # print('5='*10)
        # print(feed_forward.shape, feed_forward.sum())

        # Output
        output = self.lm_head(feed_forward + self.bias)

        # print('6='*10)
        # print(output.shape, output.sum())

        return output


class SimpleTransformerTorch(torch.nn.Module):
    def __init__(self, dim=2048, n_heads=32):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        
        # Embedding layer
        self.embd = torch.nn.Embedding(100, dim)
        torch.nn.init.normal_(self.embd.weight, mean=0.0, std=0.02)
        
        # Query/Key/Value projections
        self.wq = torch.nn.Linear(dim, dim)
        self.wk = torch.nn.Linear(dim, dim)
        self.wv = torch.nn.Linear(dim, dim)
        torch.nn.init.normal_(self.wq.weight, mean=0.0, std=0.02)
        torch.nn.init.normal_(self.wk.weight, mean=0.0, std=0.02)
        torch.nn.init.normal_(self.wv.weight, mean=0.0, std=0.02)
        
        # Attention projection
        self.wo = torch.nn.Linear(dim, dim)
        torch.nn.init.normal_(self.wo.weight, mean=0.0, std=0.02)
        
        # MLP layers
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim, 4*dim),
            torch.nn.GELU(),
            torch.nn.Linear(4*dim, dim)
        )
        torch.nn.init.normal_(self.mlp[0].weight, mean=0.0, std=0.02)
        torch.nn.init.normal_(self.mlp[2].weight, mean=0.0, std=0.02)
        
        # Output layer
        self.lm_head = torch.nn.Linear(dim, 100)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.02)
        
        # Bias parameters
        self.bias = torch.nn.Parameter(torch.zeros(dim))

    def forward(self, input_ids):
        batch_size = input_ids.shape[0]
        seq_len = input_ids.shape[1]
        
        # Embedding
        hidden_states = self.embd(input_ids)  # [batch_size, seq_len, dim]
        
        # print('1='*10)
        # print(input_ids.shape, input_ids.sum())
        # print(hidden_states.shape, hidden_states.sum())

        # Query/Key/Value projections and reshape for multi-head attention
        query = self.wq(hidden_states)  # [batch_size, seq_len, dim]
        key = self.wk(hidden_states)    # [batch_size, seq_len, dim]
        value = self.wv(hidden_states)  # [batch_size, seq_len, dim]
        
        # print('2='*10)
        # print(query.shape, query.sum())
        # print(key.shape, key.sum())
        # print(value.shape, value.sum())

        # Reshape to [batch_size, seq_len, n_heads, head_dim]
        query = query.view(batch_size, seq_len, self.n_heads, self.head_dim)
        key = key.view(batch_size, seq_len, self.n_heads, self.head_dim)
        value = value.view(batch_size, seq_len, self.n_heads, self.head_dim)
        
        # Transpose to [batch_size, n_heads, seq_len, head_dim]
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        
        # Scaled dot-product attention
        scale = self.head_dim ** -0.5
        attn_weights = torch.matmul(query * scale, key.transpose(-2, -1))
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        
        # print('3='*10)
        # print(scale)
        # print(attn_weights.shape, attn_weights.sum())

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, value)  # [batch_size, n_heads, seq_len, head_dim]
        
        # Reshape back to [batch_size, seq_len, dim]
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, seq_len, self.dim)
        
        # Attention output projection
        attn_output = self.wo(attn_output)
        
        # print('4='*10)
        # print(attn_output.shape, attn_output.sum())

        # Feed forward
        feed_forward = self.mlp(attn_output)
        
        # print('5='*10)
        # print(feed_forward.shape, feed_forward.sum())

        # Output
        output = self.lm_head(feed_forward + self.bias)

        # print('6='*10)
        # print(output.shape, output.sum())

        return output

def generate_data(batch_size=32, seq_len=64, vocab_size=100):
    x = np.random.randint(0, vocab_size, size=(batch_size, seq_len))
    y = np.random.randint(0, vocab_size, size=(batch_size, seq_len))
    return x, y

train_data = []
for _ in range(STEPS):
    x_np, y_np = generate_data()
    train_data.append((x_np, y_np))

def train_model(config, steps=STEPS):
    losses_both = {}

    model_t = SimpleTransformerTorch(dim=config['dim'], n_heads=config['n_heads']).to(torch_device)
    model_p = SimpleTransformerPaddle(dim=config['dim'], n_heads=config['n_heads'])
    
    # Copy parameters from PyTorch model to PaddlePaddle model
    for (name_t, param_t), (name_p, param_p) in zip(model_t.named_parameters(), model_p.named_parameters()):
        print(f"\nCopying: {name_t} -> {name_p}")
        print(f"Shapes before: {param_t.shape} -> {param_p.shape}")
        
        param_numpy = param_t.detach().cpu().numpy()
        
        # For linear layers' weights, we need to transpose
        is_weight = any(x in name_t for x in ['weight', 'w_0'])
        is_linear = (
            'mlp' in name_t or  # MLP layer
            'linear' in name_p.lower() or  # Regular linear layers
            any(x in name_t for x in ['wq', 'wk', 'wv', 'wo', 'lm_head'])  # Special layers
        )
        
        print(f"is_weight: {is_weight}, is_linear: {is_linear}")
        
        if is_weight and is_linear:
            print(f"Transposing parameter")
            param_numpy = param_numpy.T
            print(f"Transposed shape: {param_numpy.shape}")
        
        paddle_shape = list(param_p.shape)
        numpy_shape = list(param_numpy.shape)
        if paddle_shape != numpy_shape:
            raise ValueError(f"Shape mismatch after processing: Paddle shape {paddle_shape} != Numpy shape {numpy_shape}")
            
        param_p.set_value(paddle.to_tensor(param_numpy))
        print(f"Shapes after: {param_t.shape} -> {param_p.shape}")
        print("-" * 50)

    print("Testing AdamWMini (Paddle)...")
    # Train Paddle model
    criterion_p = paddle.nn.CrossEntropyLoss()
    model_p.train()
    
    optimizer_p = AdamWMini(
        named_parameters=model_p.named_parameters(),
        learning_rate=config['lr'],
        beta1=config['beta1'],
        beta2=config['beta2'],
        epsilon=config['epsilon'],
        weight_decay=config['weight_decay'],
        dim=config['dim'],
        n_heads=config['n_heads']
    )
    
    losses = []
    for step in range(steps):
        x_np, y_np = train_data[step]
        x = paddle.to_tensor(x_np, dtype='int64', place='gpu:0')
        y = paddle.to_tensor(y_np, dtype='int64', place='gpu:0')
        
        out = model_p(x)  # [batch_size, seq_len, vocab_size]
        out = out.reshape([-1, out.shape[-1]])  # [batch_size*seq_len, vocab_size]
        y = y.reshape([-1])  # [batch_size*seq_len]
        loss = criterion_p(out, y)
        model_p.clear_gradients()
        loss.backward()
        optimizer_p.step()
        losses.append(float(loss.numpy()))
        
        if (step+1) % 10 == 0:
            print(f'step {step+1}, Paddle Loss: {float(loss.numpy()):.4f}')
            
    losses_both['paddle'] = losses

    print("Testing Adam_mini (PyTorch)...")
    # Train Torch model
    criterion_t = torch.nn.CrossEntropyLoss()
    model_t.train()
    
    optimizer_t = Adam_mini(
        named_parameters = model_t.named_parameters(),
        lr=config['lr'],
        betas= (config['beta1'], config['beta2']),
        eps=config['epsilon'],
        weight_decay=config['weight_decay'],
        dim=config['dim'],
        n_heads=config['n_heads']
    )
    
    losses = []
    for step in range(steps):
        x_np, y_np = train_data[step]
        x = torch.tensor(x_np, dtype=torch.long, device=torch_device)
        y = torch.tensor(y_np, dtype=torch.long, device=torch_device)
        
        out = model_t(x)  # [batch_size, seq_len, vocab_size]
        out = out.reshape(-1, out.shape[-1])  # [batch_size*seq_len, vocab_size]
        y = y.reshape(-1)  # [batch_size*seq_len]
        loss = criterion_t(out, y)
        optimizer_t.zero_grad()
        loss.backward()
        optimizer_t.step()
        losses.append(float(loss.detach().cpu().numpy()))
        
        if (step+1) % 10 == 0:
            print(f'step {step+1}, Torch Loss: {float(loss.detach().cpu().numpy()):.4f}')
            
    losses_both['torch'] = losses
    
    return losses_both

# Test configurations
configs = [
    {
        'name': 'base_config',
        'lr': 1e-3,
        'beta1': 0.9,
        'beta2': 0.999,
        'epsilon': 1e-8,
        'weight_decay': 0,
        'dim': 2048,
        'n_heads': 32
    },
]

# Run the comparison
results = {}
for config in configs:
    print(f"\nTesting config: {config['name']}")
    
    loss = train_model(config)

    results[config['name']] = {
        'Adam_mini': loss['torch'],
        'AdamWMini': loss['paddle']
    }


Output log:


--------------------------------------------------
Testing AdamWMini (Paddle)...

Adam-mini found blocks:
- 1 embedding layers
- 1 output layers
- 2 Query and Key layers
- 1 Value layers
- 1 Attention projection layers
- 2 MLP layers

step 10, Paddle Loss: 4.6115
step 20, Paddle Loss: 4.6548
step 30, Paddle Loss: 4.6168
step 40, Paddle Loss: 4.6138
step 50, Paddle Loss: 4.6065
step 60, Paddle Loss: 4.6093
step 70, Paddle Loss: 4.6064
step 80, Paddle Loss: 4.6087
step 90, Paddle Loss: 4.6076
step 100, Paddle Loss: 4.6072
step 110, Paddle Loss: 4.6042
step 120, Paddle Loss: 4.6072
step 130, Paddle Loss: 4.6068
step 140, Paddle Loss: 4.6055
step 150, Paddle Loss: 4.6075
step 160, Paddle Loss: 4.6057
step 170, Paddle Loss: 4.6077
step 180, Paddle Loss: 4.6053
step 190, Paddle Loss: 4.6052
step 200, Paddle Loss: 4.6066
Testing Adam_mini (PyTorch)...
Adam-mini found the param block with name: bias torch.Size([2048])
Adam-mini found the param block with name: embd.weight torch.Size([100, 2048])
Adam-mini found the param block with name: wq.weight torch.Size([2048, 2048])
Adam-mini found the param block with name: wq.bias torch.Size([2048])
Adam-mini found the param block with name: wk.weight torch.Size([2048, 2048])
Adam-mini found the param block with name: wk.bias torch.Size([2048])
Adam-mini found the param block with name: wv.weight torch.Size([2048, 2048])
Adam-mini found the param block with name: wv.bias torch.Size([2048])
Adam-mini found the param block with name: wo.weight torch.Size([2048, 2048])
Adam-mini found the param block with name: wo.bias torch.Size([2048])
Adam-mini found the param block with name: mlp.0.weight torch.Size([8192, 2048])
Adam-mini found the param block with name: mlp.0.bias torch.Size([8192])
Adam-mini found the param block with name: mlp.2.weight torch.Size([2048, 8192])
Adam-mini found the param block with name: mlp.2.bias torch.Size([2048])
Adam-mini found the param block with name: lm_head.weight torch.Size([100, 2048])
Adam-mini found the param block with name: lm_head.bias torch.Size([100])
Adam-mini found 1 embedding layers, 1 output layers; 2 Querys and Keys;  1 Values;  1 attn_proj;  2 MLPs;
step 10, Torch Loss: 4.6139
step 20, Torch Loss: 4.6210
step 30, Torch Loss: 4.6074
step 40, Torch Loss: 4.6114
step 50, Torch Loss: 4.6053
step 60, Torch Loss: 4.6104
step 70, Torch Loss: 4.6065
step 80, Torch Loss: 4.6106
step 90, Torch Loss: 4.6071
step 100, Torch Loss: 4.6088
step 110, Torch Loss: 4.6064
step 120, Torch Loss: 4.6086
step 130, Torch Loss: 4.6075
step 140, Torch Loss: 4.6084
step 150, Torch Loss: 4.6106
step 160, Torch Loss: 4.6081
step 170, Torch Loss: 4.6112
step 180, Torch Loss: 4.6073
step 190, Torch Loss: 4.6076
step 200, Torch Loss: 4.6072


A few notes:

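  • I ran 200 steps and the losses are essentially consistent. Because this is random test data, convergence is not what matters here; comparing the Paddle and torch losses is enough. Both stay slightly above 4.6.
  • Why the losses are not exactly identical: after carefully comparing the forward and backward logs, the mismatch mainly comes from Paddle and torch producing slightly different results on extremely small values, and the deviation becomes more visible after the sqrt. As a result, the two cannot be made to match exactly.
  • The test model contains 1 embedding layer, 1 output layer, 2 Query and Key layers, 1 Value layer, 1 attn_proj layer and 2 MLP layers, which covers all branches. (torch's Adam_mini has one more branch whose logic is the same as the others; the only difference is that torch handles distributed computation there, which this Paddle version does not cover.)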
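Since the targets in the test script are drawn uniformly at random with `np.random.randint`, no model can systematically beat predicting a uniform distribution over the 100-token vocabulary, so both runs are expected to hover near ln(100). A small NumPy check of that number (the helper name `uniform_cross_entropy` is illustrative only):

```python
import math
import numpy as np

VOCAB_SIZE = 100  # vocab size used by the test script

def uniform_cross_entropy(vocab_size: int) -> float:
    """Cross-entropy when the model assigns equal probability to every class."""
    logits = np.zeros(vocab_size)                  # equal logits -> uniform softmax
    probs = np.exp(logits) / np.exp(logits).sum()
    target = 0                                     # any target gives the same value
    return float(-np.log(probs[target]))

floor = uniform_cross_entropy(VOCAB_SIZE)
print(f"{floor:.4f}")  # 4.6052, i.e. ln(100)
assert math.isclose(floor, math.log(VOCAB_SIZE))
```

Losses of roughly 4.60 to 4.61 on this data therefore mean "as good as chance", which is why only the Paddle/torch comparison, not convergence, is informative here.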

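Personally, I think Paddle and torch are aligned: the two trends match, and if there were a mistake in the algorithm logic, training would diverge within a few steps.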

@DrownFish19 could you please take a look and check for any problems? Thanks! 🤟🤟🤟

megemini, Jun 06 '25 13:06

  1. Please control randomness first and then determine the model's error. You can add the following settings:
# General environment variables, to eliminate randomness
export NVIDIA_TF32_OVERRIDE=0
export FLAGS_embedding_deterministic=1
export FLAGS_cudnn_deterministic=1

# Environment variables for parallel computation, to eliminate randomness
export Flags_mp_aysnc_allreduce=1
export Flags_skip_mp_c_identity=1
export FLAGS_shard_norm_align_dp=0
export FLAGS_shard_use_reduce=1
export FLAGS_sync_before_allreduce=1
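  2. The first-step loss must match. Then please provide the full loss data; logging every 10 steps is too coarse to confirm whether the model is fully aligned.
  3. The model error is too high for the implementation to be convincing. After adding the randomness controls above, determine the model error. Even with fp32 initialization the error is around 1e-3 (the normal range should be 1e-5); please also report the model error under bf16 (the normal range should be between 1e-3 and 1e-2).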
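To check first-step and full-curve alignment, a per-step comparison is more informative than printing every 10 steps. A minimal sketch, assuming the loss lists are the full-precision `losses_both['paddle']` and `losses_both['torch']` returned by the scripts' `train_model`; `compare_losses` is a hypothetical helper, not part of PaddleNLP, and its default tolerance simply mirrors the 1e-5 fp32 range mentioned above:

```python
import numpy as np

def compare_losses(paddle_losses, torch_losses, atol=1e-5):
    """Print per-step absolute differences between two loss curves.

    Returns (max_abs_diff, first_step_abs_diff, all_steps_within_atol).
    """
    p = np.asarray(paddle_losses, dtype=np.float64)
    t = np.asarray(torch_losses, dtype=np.float64)
    if p.shape != t.shape:
        raise ValueError(f"length mismatch: {p.shape} vs {t.shape}")
    diff = np.abs(p - t)
    for step, (lp, lt, d) in enumerate(zip(p, t, diff), start=1):
        print(f"step {step}: paddle={lp:.6f} torch={lt:.6f} |diff|={d:.2e}")
    return float(diff.max()), float(diff[0]), bool(np.all(diff <= atol))

# Usage with the first two losses printed in the updated log further down.
# Those prints have only 4 decimals, so they bound the error to about 1e-4;
# pass the unrounded lists from train_model to test a 1e-5 tolerance.
max_diff, first_diff, ok = compare_losses([5.0468, 8.0434], [5.0468, 8.0434])
```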

DrownFish19, Jun 10 '25 03:06

Update 20250610

  • Fixed the shape of the moments to align with torch

There really was a problem before! After a full day of debugging I finally found it. The cause:

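  • For torch's linear layer, torch.nn.Linear(in_features, out_features), the weight has shape (out_features, in_features); paddle.nn.Linear stores it the other way round, as (in_features, out_features).
  • However, the embedding layer does not follow this layout: its weight is (num_embeddings, embedding_dim) in both frameworks! The original Adam-mini author handles linear and embedding layers together. That is fine in torch, but doing the same in Paddle is wrong!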
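A NumPy sketch of why the reduction axis has to differ between the two layer types, assuming (as the Paddle snippet in this comment does) one second-moment value per output neuron for linear weights and one per row (token) for embedding weights. NumPy arrays stand in for the framework tensors; only the weight layouts are modeled, and the sizes are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features = 4, 3
vocab_size, dim = 5, 4

# torch.nn.Linear stores its weight (and gradient) as [out_features, in_features];
# paddle.nn.Linear stores the transpose, [in_features, out_features].
grad_linear_torch = rng.standard_normal((out_features, in_features))
grad_linear_paddle = grad_linear_torch.T

# Embedding weights are [num_embeddings, embedding_dim] in both frameworks.
grad_embd = rng.standard_normal((vocab_size, dim))

# Per-output-neuron mean of grad^2: reduce over in_features, whose axis differs.
v_linear_torch = (grad_linear_torch ** 2).mean(axis=1, keepdims=True)    # [out, 1]
v_linear_paddle = (grad_linear_paddle ** 2).mean(axis=0, keepdims=True)  # [1, out]

# Per-token mean of grad^2 for the embedding: the same axis in both frameworks.
v_embd = (grad_embd ** 2).mean(axis=1, keepdims=True)                    # [vocab, 1]

# The linear statistics agree once the layout is accounted for ...
assert np.allclose(v_linear_torch.ravel(), v_linear_paddle.ravel())
# ... whereas reusing the linear axis (0) for the embedding gives per-column values.
wrong_v_embd = (grad_embd ** 2).mean(axis=0, keepdims=True)              # [1, dim]
print(v_linear_torch.shape, v_linear_paddle.shape, v_embd.shape, wrong_v_embd.shape)
```

This prints `(3, 1) (1, 3) (5, 1) (1, 4)`: transposing the Linear weights in Paddle flips their reduction axis, while the embedding's axis must stay where it was.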

P.S. Randomness should not be an issue here, because the test script copies the torch weights into the Paddle model at the start.

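This was a nasty trap. Although the test script already transposed the torch weights before copying them into Paddle, the Paddle optimizer itself has to handle the embedding layer and the linear layers separately: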


                if any(embd_name in name for embd_name in self.embd_names):
                    mom2 = mom2 * beta2 + (1.0 - beta2) * (grad * grad).mean(axis=1, keepdim=True)
                else:
                    mom2 = mom2 * beta2 + (1.0 - beta2) * (grad * grad).mean(axis=0, keepdim=True)
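For context, the sketch below extends the branch above into one complete, bias-corrected update step. It is a simplified NumPy illustration under stated assumptions, not the PR's `AdamWMini` code: `m` keeps the full parameter shape, `v` is reduced exactly as in the snippet, weight decay and Adam-mini's per-head Query/Key blocking are omitted, and the hypothetical `is_embedding` flag stands in for the `self.embd_names` check.

```python
import numpy as np

def adam_mini_step(w, grad, m, v, step, is_embedding,
                   lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One simplified Adam-mini-style update for a 2-D weight in Paddle layout.

    Embedding weights are [num_embeddings, dim]: one v value per row (axis=1).
    Linear weights are [in_features, out_features]: one v value per column (axis=0).
    """
    axis = 1 if is_embedding else 0
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * (grad * grad).mean(axis=axis, keepdims=True)
    m_hat = m / (1.0 - beta1 ** step)            # bias correction of the first moment
    v_hat = v / (1.0 - beta2 ** step)            # bias correction of the block-wise v
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)  # v_hat broadcasts over the reduced axis
    return w, m, v

# Usage: first step on a linear weight [in=4, out=3] with a constant gradient.
w = np.zeros((4, 3))
m = np.zeros_like(w)
v = np.zeros((1, 3))  # one value per output column
w, m, v = adam_mini_step(w, np.ones_like(w), m, v, step=1, is_embedding=False)
```

With a constant unit gradient the bias-corrected ratio is 1, so every entry moves by about `-lr`; a wrong reduction axis would show up immediately as a shape mismatch in `v`.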

After re-checking the shapes one by one, I also fixed a few related places, and the precision should now be aligned.

Here is the test script. It is basically the same as before, except that DTYPE is now unified and LayerNorm layers are added:


import paddle
import torch
import numpy as np
import matplotlib.pyplot as plt
from adam_mini import Adam_mini
from optimizer import AdamWMini

# Set random seeds for reproducibility
SEED = 1
np.random.seed(SEED)
paddle.seed(SEED)
torch.manual_seed(SEED)
STEPS = 10
VOCAB_SIZE = 100
DTYPE = 'float32'

paddle.set_default_dtype(DTYPE)
torch.set_default_dtype(torch.float16 if DTYPE == 'float16' else torch.float32)

# Set default devices
paddle.set_device('gpu:0')
torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class SimpleTransformerPaddle(paddle.nn.Layer):
    def __init__(self, dim=2048, n_heads=32):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        
        # Embedding layer
        self.embd = paddle.nn.Embedding(VOCAB_SIZE, dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        
        # Query/Key/Value projections
        self.wq = paddle.nn.Linear(dim, dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        self.wk = paddle.nn.Linear(dim, dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        self.wv = paddle.nn.Linear(dim, dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        
        # Attention projection
        self.wo = paddle.nn.Linear(dim, dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        
        # LayerNorm layers
        self.ln1 = paddle.nn.LayerNorm(dim)
        self.ln2 = paddle.nn.LayerNorm(dim)
        
        # MLP layers
        self.mlp = paddle.nn.Sequential(
            paddle.nn.Linear(dim, 4*dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02)),
            paddle.nn.ReLU(),
            paddle.nn.Linear(4*dim, dim, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        )
        
        # Output layer
        self.lm_head = paddle.nn.Linear(dim, VOCAB_SIZE, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
        
        # Bias parameters
        self.bias = paddle.create_parameter([dim], dtype=DTYPE, default_initializer=paddle.nn.initializer.Constant(value=0.0))

    def forward(self, input_ids):
        batch_size = input_ids.shape[0]
        seq_len = input_ids.shape[1]
        
        # Embedding
        hidden_states = self.embd(input_ids)  # [batch_size, seq_len, dim]

        # print('1='*10)
        # print(input_ids.shape, input_ids.sum())
        # print(hidden_states.shape, hidden_states.sum())
        
        # Query/Key/Value projections and reshape for multi-head attention
        query = self.wq(hidden_states)  # [batch_size, seq_len, dim]
        key = self.wk(hidden_states)    # [batch_size, seq_len, dim]
        value = self.wv(hidden_states)  # [batch_size, seq_len, dim]
        
        # print('2='*10)
        # print(query.shape, query.sum())
        # print(key.shape, key.sum())
        # print(value.shape, value.sum())

        # Reshape to [batch_size, seq_len, n_heads, head_dim]
        query = query.reshape([batch_size, seq_len, self.n_heads, self.head_dim])
        key = key.reshape([batch_size, seq_len, self.n_heads, self.head_dim])
        value = value.reshape([batch_size, seq_len, self.n_heads, self.head_dim])

        # Transpose to [batch_size, n_heads, seq_len, head_dim]
        query = query.transpose([0, 2, 1, 3])
        key = key.transpose([0, 2, 1, 3])
        value = value.transpose([0, 2, 1, 3])
        
        # Scaled dot-product attention
        scale = self.head_dim ** -0.5
        attn_weights = paddle.matmul(query * scale, key.transpose([0, 1, 3, 2]))
        attn_weights = paddle.nn.functional.softmax(attn_weights, axis=-1)
        
        # print('3='*10)
        # print(scale)
        # print(attn_weights.shape, attn_weights.sum())

        # Apply attention to values
        attn_output = paddle.matmul(attn_weights, value)  # [batch_size, n_heads, seq_len, head_dim]
        
        # Reshape back to [batch_size, seq_len, dim]
        attn_output = attn_output.transpose([0, 2, 1, 3])
        attn_output = attn_output.reshape([batch_size, seq_len, self.dim])
        
        # Attention output projection with residual connection and layer norm
        attn_output = self.wo(attn_output)
        hidden_states = self.ln1(hidden_states + attn_output)
        
        # print('4='*10)
        # print(attn_output.shape, attn_output.sum())

        # Feed forward with residual connection and layer norm
        feed_forward = self.mlp(hidden_states)
        hidden_states = self.ln2(hidden_states + feed_forward)
        
        # print('5='*10)
        # print(feed_forward.shape, feed_forward.mean().numpy())

        # Output
        output = self.lm_head(hidden_states + self.bias)

        # print('6='*10)
        # print(output.shape, output.mean().numpy())

        return output


class SimpleTransformerTorch(torch.nn.Module):
    def __init__(self, dim=2048, n_heads=32):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        
        # Embedding layer
        self.embd = torch.nn.Embedding(VOCAB_SIZE, dim)
        torch.nn.init.normal_(self.embd.weight, mean=0.0, std=0.02)
        
        # Query/Key/Value projections
        self.wq = torch.nn.Linear(dim, dim)
        self.wk = torch.nn.Linear(dim, dim)
        self.wv = torch.nn.Linear(dim, dim)
        torch.nn.init.normal_(self.wq.weight, mean=0.0, std=0.02)
        torch.nn.init.normal_(self.wk.weight, mean=0.0, std=0.02)
        torch.nn.init.normal_(self.wv.weight, mean=0.0, std=0.02)
        
        # Attention projection
        self.wo = torch.nn.Linear(dim, dim)
        torch.nn.init.normal_(self.wo.weight, mean=0.0, std=0.02)
        
        # LayerNorm layers
        self.ln1 = torch.nn.LayerNorm(dim)
        self.ln2 = torch.nn.LayerNorm(dim)
        
        # MLP layers
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim, 4*dim),
            torch.nn.ReLU(),
            torch.nn.Linear(4*dim, dim)
        )
        torch.nn.init.normal_(self.mlp[0].weight, mean=0.0, std=0.02)
        torch.nn.init.normal_(self.mlp[-1].weight, mean=0.0, std=0.02)
    
        # Output layer
        self.lm_head = torch.nn.Linear(dim, VOCAB_SIZE)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.02)
        
        # Bias parameters
        self.bias = torch.nn.Parameter(torch.zeros(dim))

    def forward(self, input_ids):
        batch_size = input_ids.shape[0]
        seq_len = input_ids.shape[1]
        
        # Embedding
        hidden_states = self.embd(input_ids)  # [batch_size, seq_len, dim]
        
        # print('1='*10)
        # print(input_ids.shape, input_ids.sum())
        # print(hidden_states.shape, hidden_states.sum())

        # Query/Key/Value projections and reshape for multi-head attention
        query = self.wq(hidden_states)  # [batch_size, seq_len, dim]
        key = self.wk(hidden_states)    # [batch_size, seq_len, dim]
        value = self.wv(hidden_states)  # [batch_size, seq_len, dim]
        
        # print('2='*10)
        # print(query.shape, query.sum())
        # print(key.shape, key.sum())
        # print(value.shape, value.sum())

        # Reshape to [batch_size, seq_len, n_heads, head_dim]
        query = query.view(batch_size, seq_len, self.n_heads, self.head_dim)
        key = key.view(batch_size, seq_len, self.n_heads, self.head_dim)
        value = value.view(batch_size, seq_len, self.n_heads, self.head_dim)
        
        # Transpose to [batch_size, n_heads, seq_len, head_dim]
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        
        # Scaled dot-product attention
        scale = self.head_dim ** -0.5
        attn_weights = torch.matmul(query * scale, key.transpose(-2, -1))
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        
        # print('3='*10)
        # print(scale)
        # print(attn_weights.shape, attn_weights.sum())

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, value)  # [batch_size, n_heads, seq_len, head_dim]
        
        # Reshape back to [batch_size, seq_len, dim]
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, seq_len, self.dim)
        
        # Attention output projection with residual connection and layer norm
        attn_output = self.wo(attn_output)
        hidden_states = self.ln1(hidden_states + attn_output)
        
        # print('4='*10)
        # print(attn_output.shape, attn_output.sum())

        # Feed forward with residual connection and layer norm
        feed_forward = self.mlp(hidden_states)
        hidden_states = self.ln2(hidden_states + feed_forward)
        
        # print('5='*10)
        # print(feed_forward.shape, feed_forward.mean().detach().cpu().numpy())

        # Output
        output = self.lm_head(hidden_states + self.bias)

        # print('6='*10)
        # print(output.shape, output.mean().detach().cpu().numpy())

        return output

def generate_data(batch_size=32, seq_len=64, vocab_size=VOCAB_SIZE):
    x = np.random.randint(0, vocab_size, size=(batch_size, seq_len))
    y = np.random.randint(0, vocab_size, size=(batch_size, seq_len))
    return x, y

train_data = []
for _ in range(STEPS):
    x_np, y_np = generate_data()
    train_data.append((x_np, y_np))

def train_model(config, steps=STEPS, use_adamw_mini=True):
    losses_both = {}

    model_t = SimpleTransformerTorch(dim=config['dim'], n_heads=config['n_heads']).to(torch_device)
    model_p = SimpleTransformerPaddle(dim=config['dim'], n_heads=config['n_heads'])
    
    # Copy parameters from PyTorch model to PaddlePaddle model
    for (name_t, param_t), (name_p, param_p) in zip(model_t.named_parameters(), model_p.named_parameters()):
        # print(f"\nCopying: {name_t} -> {name_p}")
        # print(f"Shapes before: {param_t.shape} -> {param_p.shape}")
        
        param_numpy = param_t.detach().cpu().numpy()
        
        # For linear layers' weights, we need to transpose
        is_weight = any(x in name_t for x in ['weight', 'w_0'])
        is_linear = (
            'mlp' in name_t or  # MLP layer
            'linear' in name_p.lower() or  # Regular linear layers
            any(x in name_t for x in ['wq', 'wk', 'wv', 'wo', 'lm_head'])  # Special layers
        )
        
        # print(f"is_weight: {is_weight}, is_linear: {is_linear}")
        
        if is_weight and is_linear:
            # print(f"Transposing parameter")
            param_numpy = param_numpy.T
            # print(f"Transposed shape: {param_numpy.shape}")
        
        paddle_shape = list(param_p.shape)
        numpy_shape = list(param_numpy.shape)
        if paddle_shape != numpy_shape:
            raise ValueError(f"Shape mismatch after processing: Paddle shape {paddle_shape} != Numpy shape {numpy_shape}")
            
        param_p.set_value(paddle.to_tensor(param_numpy))
        # print(f"Shapes after: {param_t.shape} -> {param_p.shape}")
        # print("-" * 50)

    out_paddle = []
    out_torch = []

    if use_adamw_mini:
        print("Testing AdamWMini (Paddle)...")
    else:
        print("Testing AdamW (Paddle)...")

    # Train Paddle model
    criterion_p = paddle.nn.CrossEntropyLoss()
    model_p.train()
    
    if use_adamw_mini:
        optimizer_p = AdamWMini(
            named_parameters=model_p.named_parameters(),
            learning_rate=config['lr'],
            beta1=config['beta1'],
            beta2=config['beta2'],
            epsilon=config['epsilon'],
            weight_decay=config['weight_decay'],
            dim=config['dim'],
            n_heads=config['n_heads'],
            use_lowprecision_moment=DTYPE == 'float16'
        )
    else:
        optimizer_p = paddle.optimizer.AdamW(
            parameters=model_p.parameters(),
            learning_rate=config['lr'],
            beta1=config['beta1'],
            beta2=config['beta2'],
            epsilon=config['epsilon'],
            weight_decay=config['weight_decay'],
            use_lowprecision_moment=DTYPE == 'float16'
        )

    losses = []
    for step in range(steps):
        x_np, y_np = train_data[step]
        x = paddle.to_tensor(x_np, dtype='int64', place='gpu:0')
        y = paddle.to_tensor(y_np, dtype='int64', place='gpu:0')
        
        out = model_p(x)  # [batch_size, seq_len, vocab_size]
        out = out.reshape([-1, out.shape[-1]])  # [batch_size*seq_len, vocab_size]
        y = y.reshape([-1])  # [batch_size*seq_len]
        loss = criterion_p(out, y)

        # print('o'*20, out.detach().cpu().numpy().reshape(-1)[:10])
        # print('l'*20, out.detach().cpu().numpy().mean(), y.detach().cpu().numpy().sum(), loss.detach().cpu().numpy())

        out_paddle.append(out.detach().cpu().numpy().reshape(-1)[:5])

        model_p.clear_gradients()
        loss.backward()
        optimizer_p.step()
        losses.append(float(loss.numpy()))
        
        if (step+1) % 1 == 0:
            print(f'step {step+1}, Paddle Loss: {float(loss.numpy()):.4f}')
            
    losses_both['paddle'] = losses

    if use_adamw_mini:
        print("Testing Adam_mini (PyTorch)...")
    else:
        print("Testing AdamW (PyTorch)...")

    # Train Torch model
    criterion_t = torch.nn.CrossEntropyLoss()
    model_t.train()
    
    if use_adamw_mini:
        optimizer_t = Adam_mini(
            named_parameters = model_t.named_parameters(),
            lr=config['lr'],
            betas= (config['beta1'], config['beta2']),
            eps=config['epsilon'],
            weight_decay=config['weight_decay'],
            dim=config['dim'],
            n_heads=config['n_heads']
        )
    else:
        optimizer_t = torch.optim.AdamW(
            params = model_t.parameters(),
            lr=config['lr'],
            betas= (config['beta1'], config['beta2']),
            eps=config['epsilon'],
            weight_decay=config['weight_decay'],
        )

    
    losses = []
    for step in range(steps):
        x_np, y_np = train_data[step]
        x = torch.tensor(x_np, dtype=torch.long, device=torch_device)
        y = torch.tensor(y_np, dtype=torch.long, device=torch_device)
        
        out = model_t(x)  # [batch_size, seq_len, vocab_size]
        out = out.reshape(-1, out.shape[-1])  # [batch_size*seq_len, vocab_size]
        y = y.reshape(-1)  # [batch_size*seq_len]
        loss = criterion_t(out, y)

        # print('o'*20, out.detach().cpu().numpy().reshape(-1)[:10])
        # print('l'*20, out.detach().cpu().numpy().mean(), y.detach().cpu().numpy().sum(), loss.detach().cpu().numpy())
        out_torch.append(out.detach().cpu().numpy().reshape(-1)[:5])

        optimizer_t.zero_grad()
        loss.backward()
        optimizer_t.step()
        losses.append(float(loss.detach().cpu().numpy()))
        
        if (step+1) % 1 == 0:
            print(f'step {step+1}, Torch Loss: {float(loss.detach().cpu().numpy()):.4f}')
            
    losses_both['torch'] = losses

    print('*'*20)
    print('Compare outputs:')
    for i in range(STEPS):
        print(f'step {i+1}:')
        print('paddle:', out_paddle[i])
        print('torch: ', out_torch[i])

    return losses_both

# Test configurations
configs = [
    {
        'name': 'base_config',
        'lr': 1e-3,
        'beta1': 0.9,
        'beta2': 0.999,
        'epsilon': 1e-8,
        'weight_decay': 0,
        'dim': 2048,
        'n_heads': 32
    },
]

# Run the comparison
results = {}
for config in configs:
    print(f"\nTesting config: {config['name']}")
    
    print('='*50)
    print('Test adamw mini...')
    loss = train_model(config)

    results[config['name']] = {
        'Adam_mini': loss['torch'],
        'AdamWMini': loss['paddle']
    }

    print('='*50)
    print('Test original adamw...')
    loss = train_model(config, use_adamw_mini=False)

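    # Store the AdamW baseline under separate keys instead of overwriting the Adam-mini results
    results[config['name']].update({
        'AdamW_torch': loss['torch'],
        'AdamW_paddle': loss['paddle']
    })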


And here is the log:


Testing config: base_config
==================================================
Test adamw mini...
W0610 22:52:06.941300 112176 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 12.2, Runtime API Version: 11.7
W0610 22:52:06.941711 112176 gpu_resources.cc:164] device: 0, cuDNN Version: 8.5.
Testing AdamWMini (Paddle)...
W0610 22:52:07.737118 112176 gpu_resources.cc:306] WARNING: device: 0. The installed Paddle is compiled with CUDNN 8.9, but CUDNN version in your machine is 8.5, which may cause serious incompatible bug. Please recompile or reinstall Paddle with compatible CUDNN version.

Adam-mini found blocks:
- 1 embedding layers
- 1 output layers
- 2 Query and Key layers
- 1 Value layers
- 1 Attention projection layers
- 2 MLP layers

step 1, Paddle Loss: 5.0468
step 2, Paddle Loss: 8.0434
step 3, Paddle Loss: 6.9679
step 4, Paddle Loss: 6.1530
step 5, Paddle Loss: 5.9202
step 6, Paddle Loss: 5.7181
step 7, Paddle Loss: 5.4434
step 8, Paddle Loss: 5.4844
step 9, Paddle Loss: 5.2945
step 10, Paddle Loss: 5.3582
Testing Adam_mini (PyTorch)...
Adam-mini found the param block with name: bias torch.Size([2048])
Adam-mini found the param block with name: embd.weight torch.Size([100, 2048])
Adam-mini found the param block with name: wq.weight torch.Size([2048, 2048])
Adam-mini found the param block with name: wq.bias torch.Size([2048])
Adam-mini found the param block with name: wk.weight torch.Size([2048, 2048])
Adam-mini found the param block with name: wk.bias torch.Size([2048])
Adam-mini found the param block with name: wv.weight torch.Size([2048, 2048])
Adam-mini found the param block with name: wv.bias torch.Size([2048])
Adam-mini found the param block with name: wo.weight torch.Size([2048, 2048])
Adam-mini found the param block with name: wo.bias torch.Size([2048])
Adam-mini found the param block with name: ln1.weight torch.Size([2048])
Adam-mini found the param block with name: ln1.bias torch.Size([2048])
Adam-mini found the param block with name: ln2.weight torch.Size([2048])
Adam-mini found the param block with name: ln2.bias torch.Size([2048])
Adam-mini found the param block with name: mlp.0.weight torch.Size([8192, 2048])
Adam-mini found the param block with name: mlp.0.bias torch.Size([8192])
Adam-mini found the param block with name: mlp.2.weight torch.Size([2048, 8192])
Adam-mini found the param block with name: mlp.2.bias torch.Size([2048])
Adam-mini found the param block with name: lm_head.weight torch.Size([100, 2048])
Adam-mini found the param block with name: lm_head.bias torch.Size([100])
Adam-mini found 1 embedding layers, 1 output layers; 2 Querys and Keys;  1 Values;  1 attn_proj;  2 MLPs;
step 1, Torch Loss: 5.0468
step 2, Torch Loss: 8.0434
step 3, Torch Loss: 6.9679
step 4, Torch Loss: 6.1530
step 5, Torch Loss: 5.9202
step 6, Torch Loss: 5.7181
step 7, Torch Loss: 5.4434
step 8, Torch Loss: 5.4844
step 9, Torch Loss: 5.2945
step 10, Torch Loss: 5.3582
********************
Compare outputs:
step 1:
paddle: [ 0.6792311   0.79050744 -0.12007625  0.6374513   1.3632734 ]
torch:  [ 0.6792307   0.7905068  -0.12007573  0.6374507   1.3632755 ]
step 2:
paddle: [ 2.388264  -2.551482   3.1278257 -5.2603097 -9.545789 ]
torch:  [ 2.3882647 -2.5514843  3.1278288 -5.2603164 -9.545795 ]
step 3:
paddle: [ 3.4365075  -0.41926917  2.415759   -4.0009174  -8.514009  ]
torch:  [ 3.436512   -0.41924933  2.4157746  -4.000907   -8.51399   ]
step 4:
paddle: [ 2.5044622  1.6495086  2.956845  -1.741157  -7.247345 ]
torch:  [ 2.5044317  1.6495492  2.9568903 -1.7411256 -7.2472587]
step 5:
paddle: [ 1.9464118   3.5224404   1.593899    0.26571018 -5.6921353 ]
torch:  [ 1.9464757   3.5224771   1.5938914   0.26569998 -5.691975  ]
step 6:
paddle: [ 1.78667    2.8393915  1.4273245  2.2268338 -3.7377687]
torch:  [ 1.7866793  2.8393207  1.4273024  2.2268624 -3.7375774]
step 7:
paddle: [ 2.1035624  1.9680623  2.009671   3.9021509 -1.8737448]
torch:  [ 2.1035743  1.9680486  2.0096316  3.9021606 -1.8734872]
step 8:
paddle: [2.256593   1.3834429  2.2933404  2.8511925  0.10505474]
torch:  [2.2565947 1.3834566 2.2933512 2.8511777 0.1053707]
step 9:
paddle: [1.8720278 1.042617  2.2469335 1.6705302 2.219232 ]
torch:  [1.8720777 1.0426742 2.2469792 1.6705241 2.2195644]
step 10:
paddle: [1.6803029  1.1076443  2.2874732  0.96978676 3.8937752 ]
torch:  [1.6803467  1.1077304  2.2874951  0.96976256 3.8939419 ]
==================================================
Test original adamw...
Testing AdamW (Paddle)...
step 1, Paddle Loss: 4.9998
step 2, Paddle Loss: 8.9866
step 3, Paddle Loss: 7.3942
step 4, Paddle Loss: 6.5954
step 5, Paddle Loss: 6.0405
step 6, Paddle Loss: 5.8503
step 7, Paddle Loss: 5.7641
step 8, Paddle Loss: 5.6170
step 9, Paddle Loss: 5.6252
step 10, Paddle Loss: 5.5295
Testing AdamW (PyTorch)...
step 1, Torch Loss: 4.9998
step 2, Torch Loss: 8.9866
step 3, Torch Loss: 7.3942
step 4, Torch Loss: 6.5954
step 5, Torch Loss: 6.0405
step 6, Torch Loss: 5.8503
step 7, Torch Loss: 5.7641
step 8, Torch Loss: 5.6170
step 9, Torch Loss: 5.6252
step 10, Torch Loss: 5.5295
********************
Compare outputs:
step 1:
paddle: [ 0.14020315  1.0092521  -0.50811213 -0.757113    0.6712453 ]
torch:  [ 0.14020398  1.0092522  -0.50811195 -0.7571134   0.67124397]
step 2:
paddle: [ 7.7707353 -6.8123293  2.0890477  1.5774026 -5.867026 ]
torch:  [ 7.770744  -6.8123293  2.089048   1.5773983 -5.8670235]
step 3:
paddle: [-8.97218   -4.606195   3.1028814  3.3408732 -4.797612 ]
torch:  [-8.972288  -4.6061926  3.102943   3.3408554 -4.797567 ]
step 4:
paddle: [-13.143077   -2.8907413   4.363377    4.6695986  -3.3483245]
torch:  [-13.143172   -2.8907356   4.3633885   4.6695294  -3.3482106]
step 5:
paddle: [-14.768982   -1.4378625   2.5297182   2.6471355  -1.8930098]
torch:  [-14.769129   -1.4378538   2.5296354   2.647159   -1.8929064]
step 6:
paddle: [-15.746719     0.07474444   1.2834598    1.5092249   -0.05895834]
torch:  [-15.746851     0.07490155   1.2833009    1.5092876   -0.05884233]
step 7:
paddle: [-16.3188      1.7830147   0.7786581   0.9820332   1.7110273]
torch:  [-16.319        1.7831819    0.77852666   0.98210263   1.7112223 ]
step 8:
paddle: [-16.650913     3.5396793    0.57224363   0.89062434   3.47959   ]
torch:  [-16.651108     3.539764     0.57193947   0.89074177   3.4797454 ]
step 9:
paddle: [-16.750427    4.184098    0.8127401   1.0824689   4.310015 ]
torch:  [-16.750587    4.183991    0.8124665   1.0825541   4.3099046]
step 10:
paddle: [-16.657671    3.5375113   1.458569    1.574932    3.4111197]
torch:  [-16.657343    3.537442    1.4586084   1.5751177   3.4113874]



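This compares both the loss and the outputs, and also runs the original AdamW algorithm for reference. The precision looks essentially the same: the losses are identical at the printed precision, and in later steps the outputs differ by roughly 1e-4. AdamWMini and AdamW both show about this level of agreement.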
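To put a number on "roughly 1e-4", the step-10 outputs can be compared directly. This NumPy snippet only reuses values copied from the log above; the tolerances `rtol=1e-4` and `atol=1e-4` are an assumed choice for illustration, not something the test script itself checks.

```python
import numpy as np

# Step-10 outputs copied from the log above: (paddle, torch).
STEP10 = {
    "AdamWMini vs Adam_mini": (
        [1.6803029, 1.1076443, 2.2874732, 0.96978676, 3.8937752],
        [1.6803467, 1.1077304, 2.2874951, 0.96976256, 3.8939419],
    ),
    "AdamW (Paddle vs PyTorch)": (
        [-16.657671, 3.5375113, 1.458569, 1.574932, 3.4111197],
        [-16.657343, 3.537442, 1.4586084, 1.5751177, 3.4113874],
    ),
}


def max_abs_diff(a, b):
    """Largest element-wise absolute difference, computed in float64."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return float(np.max(np.abs(a - b)))


for label, (paddle_out, torch_out) in STEP10.items():
    diff = max_abs_diff(paddle_out, torch_out)
    ok = np.allclose(paddle_out, torch_out, rtol=1e-4, atol=1e-4)
    print(f"{label}: max |diff| = {diff:.2e}, allclose = {ok}")
```

Both pairs stay below 4e-4 in absolute terms, so the Adam-mini port drifts from PyTorch by about as much as Paddle's built-in AdamW does.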

Also, float16 can't be tested yet. The reason I'm reposting the test script is that I had planned to test float16 simply by setting DTYPE = 'float16', but the AdamW call fails:


Traceback (most recent call last):
  File "/home/shun/Documents/Projects/paddle/megemini/PaddleNLP/tmp/compare_optimizers.py", line 421, in <module>
    loss = train_model(config)
  File "/home/shun/Documents/Projects/paddle/megemini/PaddleNLP/tmp/compare_optimizers.py", line 329, in train_model
    optimizer_p.step()
  File "/home/shun/venv39dev/lib/python3.9/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/home/shun/venv39dev/lib/python3.9/site-packages/paddle/base/dygraph/base.py", line 386, in __impl__
    return func(*args, **kwargs)
  File "/home/shun/venv39dev/lib/python3.9/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/home/shun/venv39dev/lib/python3.9/site-packages/paddle/base/wrapped_decorator.py", line 40, in __impl__
    return wrapped_func(*args, **kwargs)
  File "/home/shun/venv39dev/lib/python3.9/site-packages/paddle/base/framework.py", line 718, in __impl__
    return func(*args, **kwargs)
  File "/home/shun/venv39dev/lib/python3.9/site-packages/paddle/optimizer/adamw.py", line 684, in step
    optimize_ops = self._apply_optimize(
  File "/home/shun/venv39dev/lib/python3.9/site-packages/paddle/optimizer/optimizer.py", line 1685, in _apply_optimize
    optimize_ops = self._create_optimization_pass(
  File "/home/shun/venv39dev/lib/python3.9/site-packages/paddle/optimizer/optimizer.py", line 1363, in _create_optimization_pass
    self._append_optimize_op(
  File "/home/shun/Documents/Projects/paddle/megemini/PaddleNLP/tmp/optimizer.py", line 208, in _append_optimize_op
    self.adamw_python(
  File "/home/shun/Documents/Projects/paddle/megemini/PaddleNLP/tmp/optimizer.py", line 261, in adamw_python
    _, _, _, _, _, _, _ = _C_ops.adamw_(
ValueError: (InvalidArgument) The type of data we are trying to retrieve (float32) does not match the type of data (float16) currently contained in the container.
  [Hint: Expected dtype() == phi::CppTypeToDataType<T>::Type(), but received dtype():15 != phi::CppTypeToDataType<T>::Type():10.] (at ../paddle/phi/core/dense_tensor.cc:153)


I checked the input arguments and they are all float16, so I don't know why _C_ops.adamw_ fails.

@DrownFish19 could you take a look? Thanks! :)

megemini, Jun 10 '25 15:06

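> I checked the input arguments and they are all float16, so I don't know why _C_ops.adamw_ fails.

I'm not sure of the exact cause either. I ran into this once before, and in that case the data passed into _C_ops.adamw_ was fp32, so it's worth checking whether any of the input arguments are fp32.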
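The error message says the kernel tried to read a tensor as float32 while that tensor actually held float16, so at least one input the kernel expects to be float32 arrived as float16. One quick way to follow the suggestion above is to group every argument passed to `_C_ops.adamw_` by dtype. This helper is a hypothetical sketch: it only assumes each object exposes a `.dtype` attribute, and the example uses NumPy arrays as stand-ins for the real Paddle tensors and argument names.

```python
from collections import defaultdict

import numpy as np


def group_by_dtype(named_tensors):
    """Map each dtype name to the list of input names that carry it."""
    groups = defaultdict(list)
    for name, tensor in named_tensors.items():
        # NumPy prints 'float16'; Paddle is assumed to print 'paddle.float16',
        # so keep only the part after the last dot.
        dtype_name = str(tensor.dtype).split(".")[-1]
        groups[dtype_name].append(name)
    return dict(groups)


# Hypothetical stand-ins for the arguments of an AdamW-style kernel call.
inputs = {
    "param": np.zeros(4, dtype=np.float16),
    "grad": np.zeros(4, dtype=np.float16),
    "moment1": np.zeros(4, dtype=np.float32),
    "lr": np.array([1e-3], dtype=np.float32),
}
print(group_by_dtype(inputs))
```

With real tensors, any group that differs from what the kernel expects for a given argument points at the culprit.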

For the CI issue, try re-running it and see.

DrownFish19, Jun 11 '25 13:06

@DrownFish19 CI has passed. Is there anything else that needs to be done? 🤗

megemini, Jun 12 '25 05:06

Please add a CI case for adamw_mini. You can base it on the test case above; the CI test should make sure Adam-mini runs end to end.

lugimzzz, Jun 18 '25 04:06

> Please add a CI case for adamw_mini. You can base it on the test case above; the CI test should make sure Adam-mini runs end to end.

I've added a test. It uses the same model as the test above and covers all branches.

Here is the test log:


> python -m unittest test_adamw_mini.py
/home/shun/venv39dev/lib/python3.9/site-packages/paddle/utils/cpp_extension/extension_utils.py:711: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md
  warnings.warn(warning_message)
W0618 12:52:03.121173 33099 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 12.2, Runtime API Version: 11.7
W0618 12:52:03.121762 33099 gpu_resources.cc:164] device: 0, cuDNN Version: 8.5.
W0618 12:52:03.826521 33099 gpu_resources.cc:306] WARNING: device: 0. The installed Paddle is compiled with CUDNN 8.9, but CUDNN version in your machine is 8.5, which may cause serious incompatible bug. Please recompile or reinstall Paddle with compatible CUDNN version.
/home/shun/venv39dev/lib/python3.9/site-packages/paddle/base/dygraph/math_op_patch.py:183: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)
  return int(np.array(var))
[2025-06-18 12:52:04,136] [    INFO] - 
Adam-mini found blocks:
[2025-06-18 12:52:04,136] [    INFO] - - 1 embedding layers
[2025-06-18 12:52:04,136] [    INFO] - - 1 output layers
[2025-06-18 12:52:04,136] [    INFO] - - 2 Query and Key layers
[2025-06-18 12:52:04,136] [    INFO] - - 1 Value layers
[2025-06-18 12:52:04,136] [    INFO] - - 1 Attention projection layers
[2025-06-18 12:52:04,136] [    INFO] - - 2 MLP layers

.
----------------------------------------------------------------------
Ran 1 test in 1.274s

OK
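The test is described as covering every branch. For readers unfamiliar with those branches, here is a NumPy sketch of the block-wise second moment as we understand it from the Adam-mini paper: embedding and output layers keep a per-element v as in AdamW, Query/Key share one v per attention head, and Value, the attention projection and the MLP weights share one v per output row. This is not the PR's code. The function names are hypothetical, weight decay is omitted (the test config above uses `weight_decay` 0), and it assumes PyTorch's `[out_features, in_features]` weight layout; Paddle's `nn.Linear` stores weights as `[in_features, out_features]`, so a Paddle version would reduce over the other axis.

```python
import numpy as np

ROW_KINDS = ("value", "attn_proj", "mlp")


def block_mean_sq(g, kind, n_heads=None):
    """Reduce g**2 to the shape of the shared second moment for this block kind."""
    g2 = g * g
    if kind in ("embedding", "output"):
        return g2  # per element, as in plain AdamW
    if kind == "query_key":
        # rows are grouped head by head, so this yields one mean per head
        return g2.reshape(n_heads, -1).mean(axis=1, keepdims=True)
    if kind in ROW_KINDS:
        # one mean per output row (assumes [out_features, in_features] layout)
        return g2.reshape(g2.shape[0], -1).mean(axis=1, keepdims=True)
    raise ValueError(f"unknown block kind: {kind}")


def init_state(p, kind, n_heads=None):
    """Zero first moment with p's shape; zero second moment with the block shape."""
    return np.zeros_like(p), np.zeros_like(block_mean_sq(p, kind, n_heads))


def block_adam_step(p, g, m, v, step, kind, n_heads=None,
                    lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One bias-corrected Adam step with a block-shared v (no weight decay)."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * block_mean_sq(g, kind, n_heads)
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    if kind in ("embedding", "output"):
        update = m_hat / (np.sqrt(v_hat) + eps)
    else:
        # broadcast each block's scalar over its rows, then restore p's shape
        blocks = v_hat.shape[0]
        update = (m_hat.reshape(blocks, -1) / (np.sqrt(v_hat) + eps)).reshape(p.shape)
    return p - lr * update, m, v


# Usage: a tiny Q/K weight with 2 heads and a constant gradient.
p = np.zeros((8, 4))
m, v = init_state(p, "query_key", n_heads=2)
p, m, v = block_adam_step(p, np.ones((8, 4)), m, v, step=1, kind="query_key", n_heads=2)
print(v.shape, p[0, 0])
```

The memory saving comes from `v`: for the 2048x2048 `wq.weight` with 32 heads in the test config, this scheme stores 32 values instead of 2048 * 2048.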

megemini, Jun 18 '25 05:06

@DrownFish19 could you help take a look at the PaddleNLP-CI-Unittest-GPU CI? It timed out. Was it caused by my PR? I doubt it...

megemini, Jun 19 '25 07:06