
assert all((~torch.isinf(scores.view(-1))) & (~torch.isnan(scores.view(-1)))) [rank0]: AssertionError

Open Harryjun opened this issue 1 year ago • 21 comments


OPTS=""
# model
OPTS+=" --base-path ${BASE_PATH}"
OPTS+=" --model-path ${CKPT}"
OPTS+=" --teacher-model-path ${TEACHER_CKPT}"
OPTS+=" --ckpt-name ${CKPT_NAME}"
OPTS+=" --teacher-ckpt-name ${TEACHER_CKPT_NAME}"
OPTS+=" --n-gpu ${GPUS_PER_NODE}"
OPTS+=" --n-nodes ${NNODES}"
OPTS+=" --model-type qwen2"
OPTS+=" --teacher-model-fp16"
OPTS+=" --gradient-checkpointing"
# OPTS+=" --model-parallel"
# OPTS+=" --model-parallel-size ${MP_SIZE}"
# data
OPTS+=" --prompt-data-dir ${PROMPT_DATA_DIR}"
OPTS+=" --only-prompt"

# OPTS+=" --lm-data-dir ${LM_DATA_DIR}"
OPTS+=" --dev-num 1000"
OPTS+=" --num-workers 0"
# hp
OPTS+=" --epochs 3"
# OPTS+=" --total-iters 5000"
OPTS+=" --kd-ratio 0.5"
OPTS+=" --batch-size ${BATCH_SIZE}"
OPTS+=" --lr 5e-6"
OPTS+=" --lr-min 5e-6"
OPTS+=" --gradient-accumulation-steps ${GRAD_ACC}"
OPTS+=" --max-length 4596"
OPTS+=" --max-prompt-length 4096"
OPTS+=" --warmup-iters 100"
OPTS+=" --scheduler-name cosine_trm"
# runtime
OPTS+=" --save ${SAVE_PATH}"
OPTS+=" --seed 10"
OPTS+=" --seed-ppo 42"
OPTS+=" --seed-lm 7"
OPTS+=" --save-interval 500"
OPTS+=" --eval-interval 100"
OPTS+=" --log-interval 16"
OPTS+=" --mid-log-num 1"
# ppo
OPTS+=" --type minillm"
OPTS+=" --ppo-epochs 4"
OPTS+=" --num-rollouts 256"
OPTS+=" --chunk-size ${CHUNK_SIZE}"

# OPTS+=" --type kd"

# minillm
OPTS+=" --length-norm"
OPTS+=" --single-step-reg"
OPTS+=" --teacher-mixed-alpha 0.4"
# reward
OPTS+=" --reward-scaling 0.6"
OPTS+=" --cliprange-reward 100"
# gen
# OPTS+=" --do-sample"
# OPTS+=" --top-k 0"
# OPTS+=" --top-p 1.0"
# OPTS+=" --temperature 1.0"
# deepspeed
OPTS+=" --deepspeed"
OPTS+=" --deepspeed_config ${BASE_PATH}/configs/deepspeed/ds_config_zero1_fp16.json"

export NCCL_DEBUG=""
export WANDB_DISABLED=True
export TF_CPP_MIN_LOG_LEVEL=3
export PYTHONPATH=${BASE_PATH}
CMD="torchrun ${DISTRIBUTED_ARGS} ${BASE_PATH}/train_minillm.py ${OPTS} $@"
# CMD="python3 ${BASE_PATH}/train_minillm.py ${OPTS}"

echo ${CMD}
echo "PYTHONPATH=${PYTHONPATH}"
mkdir -p ${SAVE_PATH}
${CMD}

Error: Image

Harryjun avatar Dec 12 '24 17:12 Harryjun

If I change it to

OPTS+=" --max-length 4097"
OPTS+=" --max-prompt-length 4096"

there is no problem. Why?

Harryjun avatar Dec 12 '24 17:12 Harryjun

When selection_value is 0, next_state_value is NaN. Is that a bug?

Harryjun avatar Dec 13 '24 04:12 Harryjun

If I change it to

OPTS+=" --max-length 4097"
OPTS+=" --max-prompt-length 4096"

there is no problem. Why?

Thanks for helping me solve this problem, man. I thought Qwen simply couldn't run, which was a real headache.

liuchen6667 avatar Dec 13 '24 07:12 liuchen6667

When selection_value is 0, next_state_value is NaN. Is that a bug?

It's a little bit weird because next_state_value is obtained by taking torch.logsumexp for current_logits. torch.logsumexp outputs nan or inf only if there are nan or inf in current_logits. You can check this by adding

print(all((~torch.isinf(current_logits.view(-1))) & (~torch.isnan(current_logits.view(-1)))))

after line 61. If the output is False, you have probably set some values in the logits to 'inf' or '-inf'.
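A minimal sketch of that check, assuming current_logits is a float tensor of shape (batch, seq, vocab). The tensor values and the helper name all_finite are made up for illustration and are not taken from the repository. torch.isfinite tests for inf, -inf, and NaN in one call, which is the same condition as the combined isinf/isnan expression above.

```python
import torch


def all_finite(t: torch.Tensor) -> bool:
    """Return True when t contains no inf, -inf, or NaN entries."""
    return bool(torch.isfinite(t).all())


# Made-up logits with shape (batch=1, seq=3, vocab=4), for illustration only.
current_logits = torch.tensor([[
    [1.0, 2.0, 3.0, 4.0],            # ordinary row
    [float("-inf")] * 4,             # every vocabulary entry set to -inf
    [1.0, float("nan"), 3.0, 4.0],   # one NaN entry
]])

# logsumexp over the vocabulary axis, one value per sequence position.
next_state_value = torch.logsumexp(current_logits, dim=-1)

print(all_finite(current_logits))        # False: rows 1 and 2 are non-finite
print(torch.isfinite(next_state_value))  # per-position finiteness, shape (1, 3)
```

Only the first position comes out finite here, so printing torch.isfinite(next_state_value) per position can help locate which rows of the logits to inspect.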

t1101675 avatar Dec 13 '24 15:12 t1101675

@liuchen6667 This is wrong. My change only limits the output to a single token; it is not a correct fix. As @t1101675 said, I think the Qwen output may be empty, or the terminator may have been changed to inf. It may be a model problem, or the code may be incompatible. If more people hit this problem, the code is most likely incompatible.

Harryjun avatar Dec 15 '24 05:12 Harryjun

@liuchen6667 This is wrong. My change only limits the output to a single token; it is not a correct fix. As @t1101675 said, I think the Qwen output may be empty, or the terminator may have been changed to inf. It may be a model problem, or the code may be incompatible. If more people hit this problem, the code is most likely incompatible.

Indeed, the original paper doesn't mention Qwen2.5 either. It feels like the code simply hasn't been adapted for it. Tough.

liuchen6667 avatar Dec 15 '24 06:12 liuchen6667

@t1101675 Hi, can you adapt the code for the Qwen model?

Harryjun avatar Dec 16 '24 03:12 Harryjun

I tried printing mask, selection_value, and next_state_value. The mask is

([[ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0') tensor([[ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')

the selection_value is

tensor([[25.5312, 28.5781,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [31.6406, 28.4844,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [30.3594, 28.5781,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [28.8281, 28.6719,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [29.0000, 28.4375,  0.0000,  ...,  0.0000, -0.0000,  0.0000],
        [34.7812, 28.6250,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', dtype=torch.float16)

the next_state_value is

tensor([[25.7344, 28.5781,     nan,  ...,     nan,     nan,     nan],
        [31.6406, 28.4844,     nan,  ...,     nan,     nan,     nan],
        [30.3594, 28.5781,     nan,  ...,     nan,     nan,     nan],
        ...,
        [28.8594, 28.6719,     nan,  ...,     nan,     nan,     nan],
        [29.0312, 28.4375,     nan,  ...,     nan,     nan,     nan],
        [34.7812, 28.6250,     nan,  ...,     nan,     nan,     nan]],

next_state_value is NaN at the positions masked out by mask, so scores = selection_value - next_state_value ends up with non-finite values. How can this be solved, @t1101675?
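A small reproduction of the printed situation, as a sketch rather than the repository's code. It reuses the first two and last two columns of the first printed row (in float32 instead of float16) and evaluates the same expression as the failing assert.

```python
import torch

# Values copied from the first printed row: two valid positions, two masked ones.
mask = torch.tensor([[True, True, False, False]])
selection_value = torch.tensor([[25.5312, 28.5781, 0.0, 0.0]])
next_state_value = torch.tensor([[25.7344, 28.5781, float("nan"), float("nan")]])

# Any arithmetic with NaN yields NaN, so 0.0 - NaN is NaN.
scores = selection_value - next_state_value

# The same condition the failing assert evaluates.
check = all((~torch.isinf(scores.view(-1))) & (~torch.isnan(scores.view(-1))))

# Positions of scores that are inf, -inf, or NaN.
non_finite_positions = ~torch.isfinite(scores)

print(bool(check))                               # whether the assert would pass
print(non_finite_positions)                      # where scores is non-finite
print(torch.equal(non_finite_positions, ~mask))  # do they coincide with masked positions?
```

With these values the non-finite entries of scores sit exactly at the positions where mask is False, which matches the pattern in the printed tensors.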

Harryjun avatar Dec 16 '24 09:12 Harryjun

@t1101675

        #  add by m , delete the output is 0.
        next_state_value = torch.where(torch.isinf(next_state_value), torch.zeros_like(next_state_value), next_state_value)
        next_state_value = next_state_value * mask[:, :-1]

I added this and it solved the problem, @liuchen6667.

Harryjun avatar Dec 16 '24 18:12 Harryjun

@t1101675

        #  add by m , delete the output is 0.
        next_state_value = torch.where(torch.isinf(next_state_value), torch.zeros_like(next_state_value), next_state_value)
        next_state_value = next_state_value * mask[:, :-1]

I added this and it solved the problem, @liuchen6667.

Is this really reliable, man? I think I'd rather start over on my own. This code is so complicated.

liuchen6667 avatar Dec 17 '24 16:12 liuchen6667

@t1101675

        #  add by m , delete the output is 0.
        next_state_value = torch.where(torch.isinf(next_state_value), torch.zeros_like(next_state_value), next_state_value)
        next_state_value = next_state_value * mask[:, :-1]

I added this and it solved the problem, @liuchen6667.

I'm not certain if this solution works as intended. I suspect that current_logits may contain NaN, Inf, or extremely large values, which could cause next_state_value to become NaN after applying torch.logsumexp. I'd be happy to take a closer look if you could provide more details about the configurations you're using, such as the model, tokenization method, etc.
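One way to probe the quoted patch on a toy input, as a sketch under assumptions: the tensors below are illustrative, and mask is given the same shape as next_state_value so the mask[:, :-1] slice is not needed. torch.isinf does not flag NaN, and NaN multiplied by 0 is still NaN, so the patch leaves NaN entries in place. Selecting with torch.where on the mask is shown as one possible alternative, not as the maintainers' fix.

```python
import torch

mask = torch.tensor([[True, True, False, False]])
next_state_value = torch.tensor([[25.7344, 28.5781, float("nan"), float("nan")]])

# The patch quoted above: replace inf with 0, then multiply by the mask.
patched = torch.where(
    torch.isinf(next_state_value),
    torch.zeros_like(next_state_value),
    next_state_value,
)
patched = patched * mask  # NaN * 0 is still NaN

# Alternative sketch: pick values by mask instead of multiplying,
# so masked positions never take part in any arithmetic.
selected = torch.where(mask, next_state_value, torch.zeros_like(next_state_value))

print(torch.isnan(patched).any().item())      # does the patched tensor still hold NaN?
print(torch.isfinite(selected).all().item())  # is the selected tensor fully finite?
```

Selecting by mask only hides non-finite values at masked positions. It does not explain why current_logits produced them in the first place, so checking the logits themselves is still worthwhile.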

t1101675 avatar Dec 17 '24 20:12 t1101675

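@liuchen6667 How did it go? Honestly, I really don't know whether the results are any good, but the code is quite hard to get running. I don't know whether anyone has managed to run it. It has been a bit of a waste of time.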

Harryjun avatar Dec 18 '24 09:12 Harryjun

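@liuchen6667 How did it go? Honestly, I really don't know whether the results are any good, but the code is quite hard to get running. I don't know whether anyone has managed to run it. It has been a bit of a waste of time.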

The original paper doesn't mention Qwen either. My guess is that it's a compatibility issue. I'd suggest not agonizing over it.

liuchen6667 avatar Dec 18 '24 10:12 liuchen6667

@t1101675 Image

Is this expected? tot_loss is unstable and is sometimes even negative.

Harryjun avatar Dec 18 '24 12:12 Harryjun

This doesn't look as expected. tot_loss is unlikely to be negative.

t1101675 avatar Dec 18 '24 13:12 t1101675

@t1101675 Could I add you on WeChat to discuss this? Mine is: junge1300780479

Harryjun avatar Dec 18 '24 17:12 Harryjun

@liuchen6667 After switching to their new code, this problem seems to be gone. I didn't add LM data, only prompt data.

Harryjun avatar Dec 19 '24 08:12 Harryjun

@liuchen6667 You can also add me on WeChat to talk it over. Mine is: junge1300780479

Harryjun avatar Dec 19 '24 08:12 Harryjun

@Harryjun Hi, I ran into this problem and followed your steps, but I still get the error assert all((~torch.isinf(scores.view(-1))) & (~torch.isnan(scores.view(-1)))) [rank0]: AssertionError. How did you eventually solve it?

chenhao-stick-to avatar Jan 17 '25 10:01 chenhao-stick-to

This is how it was solved.

aqe670 avatar Jan 22 '25 04:01 aqe670

@t1101675 Have you figured out what causes this problem? I ran into the same issue when running MiniLLM distillation with LLaMA 7B-13B without an LM dataset.

pipilia avatar Jun 09 '25 12:06 pipilia