[Bug] Shape mismatch when running MPO training on multiple nodes
Checklist
- [X] 1. I have searched related issues but cannot get the expected help.
- [X] 2. The bug has not been fixed in the latest version.
- [X] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
Describe the bug
When I run MPO training of InternVL2.5-8B on a single node with 8 GPUs, training works normally. With 3 nodes (24 GPUs), however, the following warning appears and the loss is 0:
warning: shape mismatch: value tensor of shape [4608, 4096] cannot be broadcast to indexing result of shape [1098, 4096], input_embeds[selected].shape=torch.Size([1098, 4096]), vit_embeds.shape=torch.Size([4608, 4096])
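The wording of the warning suggests that the vision-encoder output (`vit_embeds`, 4608 rows) is being written into the positions of the image-context tokens in the text embeddings (`input_embeds[selected]`, 1098 rows), and that this assignment only works when both counts agree. The pure-Python sketch below illustrates that consistency check. `IMG_CONTEXT_ID` is a hypothetical placeholder, not InternVL's real token id, and the function is illustrative only and does not come from the InternVL code base.

```python
from typing import List, Tuple

# Hypothetical stand-in for the <IMG_CONTEXT> token id (assumption, not the real id).
IMG_CONTEXT_ID = -200


def check_image_token_alignment(input_ids: List[int], vit_embed_rows: int) -> Tuple[bool, int]:
    """Count image-context slots in input_ids and compare them with the
    number of visual-token rows produced by the vision encoder."""
    n_slots = sum(1 for tok in input_ids if tok == IMG_CONTEXT_ID)
    return n_slots == vit_embed_rows, n_slots


# Numbers from the reported warning: 1098 slots in the text, 4608 ViT rows.
ids = [IMG_CONTEXT_ID] * 1098 + [1, 2, 3]
aligned, slots = check_image_token_alignment(ids, 4608)
print(aligned, slots)  # False 1098
```

Running a check like this per sample on each rank could help show which samples produce the mismatch in the 3-node run.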
Reproduction
torchrun --master_port=${MASTER_PORT} \
  --nnodes=${NNODES} \
  --node_rank=${NODE_RANK} \
  --master_addr=${MASTER_ADDR} \
  --nproc_per_node=${NPROC_PER_NODE} \
  internvl/train/internvl_chat_dpo.py \
  --model_name_or_path "./pretrained/InternVL2_5-8B-MPO" \
  --conv_style "internvl2_5" \
  --use_fast_tokenizer False \
  --output_dir ${OUTPUT_DIR} \
  --meta_path ./shell/data/relevance_v2.4.1_dpo.json \
  --overwrite_output_dir True \
  --force_image_size 448 \
  --down_sample_ratio 0.5 \
  --drop_path_rate 0.1 \
  --pad2square False \
  --freeze_llm False \
  --freeze_mlp False \
  --freeze_backbone False \
  --vision_select_layer -1 \
  --use_data_resampling False \
  --dataloader_num_workers 8 \
  --bf16 True \
  --num_train_epochs 1 \
  --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \
  --gradient_accumulation_steps ${GRADIENT_ACC} \
  --evaluation_strategy "no" \
  --save_strategy "no" \
  --save_steps 100 \
  --save_total_limit 100 \
  --learning_rate 1e-6 \
  --weight_decay 0.05 \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --max_seq_length 1024 \
  --do_train True \
  --grad_checkpoint True \
  --group_by_length False \
  --dynamic_image_size True \
  --use_thumbnail True \
  --ps_version 'v2' \
  --deepspeed "zero_stage1_config.json" \
  --report_to "tensorboard" \
  --loss_type sigmoid,bco_pair \
  --sigmoid_loss_weight 0.8 \
  --bco_pair_loss_weight 0.2 \
  --rpo_alpha 1 \
  --use_liger True \
  2>&1 | tee -a "$LOG_FILE"
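With `--force_image_size 448` and `--down_sample_ratio 0.5`, each image tile should contribute a fixed number of visual tokens. The back-of-the-envelope sketch below assumes a ViT patch size of 14, which is not set in the command above. Under that assumption, the 4608 ViT rows in the warning correspond to 18 tiles, which alone exceed `--max_seq_length 1024`. One hypothesis worth checking, though the thread does not confirm it, is that the text side was truncated to fit the length limit while the vision side was not.

```python
def image_tokens_per_tile(force_image_size: int, patch_size: int, down_sample_ratio: float) -> int:
    """Visual tokens one tile contributes after pixel-shuffle downsampling."""
    grid = force_image_size // patch_size  # e.g. 448 // 14 = 32 patches per side
    side = int(grid * down_sample_ratio)   # e.g. 32 * 0.5 = 16 tokens per side
    return side * side


# Flags from the command; PATCH_SIZE = 14 is an assumption, not a CLI flag.
PATCH_SIZE = 14
per_tile = image_tokens_per_tile(448, PATCH_SIZE, 0.5)
vit_rows = 4608  # from vit_embeds.shape in the warning
tiles = vit_rows // per_tile
max_seq_length = 1024

print(per_tile, tiles)                    # 256 18
print(tiles * per_tile > max_seq_length)  # True: image tokens alone exceed the limit
```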
Environment
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_kmp_llvm conda-forge
absl-py 2.1.0 pypi_0 pypi
accelerate 1.1.1 pypi_0 pypi
addict 2.4.0 pypi_0 pypi
aiohappyeyeballs 2.4.4 pypi_0 pypi
aiohttp 3.11.9 pypi_0 pypi
aiosignal 1.3.1 pypi_0 pypi
annotated-types 0.7.0 pypi_0 pypi
attrs 24.2.0 pypi_0 pypi
blas 2.116 mkl conda-forge
blas-devel 3.9.0 16_linux64_mkl conda-forge
boto3 1.35.73 pypi_0 pypi
botocore 1.35.73 pypi_0 pypi
brotli-python 1.1.0 py311hfdbb021_2 conda-forge
bzip2 1.0.8 h4bc722e_7 conda-forge
ca-certificates 2024.8.30 hbcca054_0 conda-forge
certifi 2024.8.30 pyhd8ed1ab_0 conda-forge
cffi 1.17.1 py311hf29c0ef_0 conda-forge
charset-normalizer 3.4.0 pyhd8ed1ab_0 conda-forge
coloredlogs 15.0.1 pypi_0 pypi
contourpy 1.3.1 pypi_0 pypi
cuda-cudart 12.1.105 0 nvidia
cuda-cupti 12.1.105 0 nvidia
cuda-libraries 12.1.0 0 nvidia
cuda-nvrtc 12.1.105 0 nvidia
cuda-nvtx 12.1.105 0 nvidia
cuda-opencl 12.6.77 0 nvidia
cuda-runtime 12.1.0 0 nvidia
cuda-version 12.6 3 nvidia
cycler 0.12.1 pypi_0 pypi
datasets 3.1.0 pypi_0 pypi
decord 0.6.0 pypi_0 pypi
deepspeed 0.15.4 pypi_0 pypi
dill 0.3.8 pypi_0 pypi
docstring-parser 0.16 pypi_0 pypi
einops 0.8.0 pypi_0 pypi
environs 11.0.0 pypi_0 pypi
ffmpeg 4.3 hf484d3e_0 pytorch
filelock 3.16.1 pyhd8ed1ab_0 conda-forge
flash-attn 2.7.0.post2 pypi_0 pypi
fonttools 4.55.0 pypi_0 pypi
freetype 2.12.1 h267a509_2 conda-forge
frozenlist 1.5.0 pypi_0 pypi
fsspec 2024.9.0 pypi_0 pypi
giflib 5.2.2 hd590300_0 conda-forge
gmp 6.3.0 hac33072_2 conda-forge
gmpy2 2.1.5 py311h0f6cedb_2 conda-forge
gnutls 3.6.13 h85f3911_1 conda-forge
grpcio 1.68.1 pypi_0 pypi
h2 4.1.0 pyhd8ed1ab_0 conda-forge
hjson 3.1.0 pypi_0 pypi
hpack 4.0.0 pyh9f0ad1d_0 conda-forge
huggingface-hub 0.26.3 pypi_0 pypi
humanfriendly 10.0 pypi_0 pypi
humanize 4.11.0 pypi_0 pypi
hyperframe 6.0.1 pyhd8ed1ab_0 conda-forge
icu 73.2 h59595ed_0 conda-forge
idna 3.10 pyhd8ed1ab_0 conda-forge
imageio 2.36.1 pypi_0 pypi
jinja2 3.1.4 pyhd8ed1ab_0 conda-forge
jmespath 1.0.1 pypi_0 pypi
jpeg 9e h166bdaf_2 conda-forge
kiwisolver 1.4.7 pypi_0 pypi
lame 3.100 h166bdaf_1003 conda-forge
lcms2 2.15 hfd0df8a_0 conda-forge
ld_impl_linux-64 2.43 h712a8e2_2 conda-forge
lerc 4.0.0 h27087fc_0 conda-forge
libblas 3.9.0 16_linux64_mkl conda-forge
libcblas 3.9.0 16_linux64_mkl conda-forge
libcublas 12.1.0.26 0 nvidia
libcufft 11.0.2.4 0 nvidia
libcufile 1.11.1.6 0 nvidia
libcurand 10.3.7.77 0 nvidia
libcusolver 11.4.4.55 0 nvidia
libcusparse 12.0.2.55 0 nvidia
libdeflate 1.17 h0b41bf4_0 conda-forge
libexpat 2.6.4 h5888daf_0 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc 14.2.0 h77fa898_1 conda-forge
libgcc-ng 14.2.0 h69a702a_1 conda-forge
libgfortran 14.2.0 h69a702a_1 conda-forge
libgfortran-ng 14.2.0 h69a702a_1 conda-forge
libgfortran5 14.2.0 hd5240d6_1 conda-forge
libhwloc 2.11.2 default_he43201b_1000 conda-forge
libiconv 1.17 hd590300_2 conda-forge
libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
liblapack 3.9.0 16_linux64_mkl conda-forge
liblapacke 3.9.0 16_linux64_mkl conda-forge
libnpp 12.0.2.50 0 nvidia
libnsl 2.0.1 hd590300_0 conda-forge
libnvjitlink 12.1.105 0 nvidia
libnvjpeg 12.1.1.14 0 nvidia
libpng 1.6.43 h2797004_0 conda-forge
libsqlite 3.46.0 hde9e2c9_0 conda-forge
libstdcxx 14.2.0 hc0a3c3a_1 conda-forge
libstdcxx-ng 14.2.0 h4852527_1 conda-forge
libtiff 4.5.0 h6adf6a1_2 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libwebp 1.2.4 h1daa5a0_1 conda-forge
libwebp-base 1.2.4 h166bdaf_0 conda-forge
libxcb 1.13 h7f98852_1004 conda-forge
libxcrypt 4.4.36 hd590300_1 conda-forge
libxml2 2.12.7 hc051c1a_1 conda-forge
libzlib 1.2.13 h4ab18f5_6 conda-forge
liger-kernel 0.4.2 pypi_0 pypi
llvm-openmp 15.0.7 h0cdce71_0 conda-forge
lmdeploy 0.6.4 pypi_0 pypi
loguru 0.7.2 pypi_0 pypi
markdown 3.7 pypi_0 pypi
markdown-it-py 3.0.0 pypi_0 pypi
markupsafe 3.0.2 py311h2dc5d0c_0 conda-forge
marshmallow 3.23.1 pypi_0 pypi
matplotlib 3.9.3 pypi_0 pypi
mdurl 0.1.2 pypi_0 pypi
mkl 2022.1.0 h84fe81f_915 conda-forge
mkl-devel 2022.1.0 ha770c72_916 conda-forge
mkl-include 2022.1.0 h84fe81f_915 conda-forge
mmengine 0.10.5 pypi_0 pypi
mpc 1.3.1 h24ddda3_1 conda-forge
mpfr 4.2.1 h90cbb55_3 conda-forge
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
msgpack 1.1.0 pypi_0 pypi
multidict 6.1.0 pypi_0 pypi
multiprocess 0.70.16 pypi_0 pypi
multiprocessing-logging 0.3.4 pypi_0 pypi
ncurses 6.5 he02047a_1 conda-forge
nettle 3.6 he412f7d_0 conda-forge
networkx 3.4.2 pyh267e887_2 conda-forge
ninja 1.11.1.2 pypi_0 pypi
numpy 2.1.3 py311h71ddf71_0 conda-forge
opencv-python 4.10.0.84 pypi_0 pypi
openh264 2.1.1 h780b84a_0 conda-forge
openjpeg 2.5.0 hfec8fc6_2 conda-forge
openssl 3.4.0 hb9d3cd8_0 conda-forge
packaging 24.2 pypi_0 pypi
pandas 2.2.3 pypi_0 pypi
peft 0.13.2 pypi_0 pypi
pillow 9.4.0 py311h50def17_1 conda-forge
pip 24.3.1 pyh8b19718_0 conda-forge
platformdirs 4.3.6 pypi_0 pypi
propcache 0.2.1 pypi_0 pypi
protobuf 5.29.0 pypi_0 pypi
psutil 6.1.0 pypi_0 pypi
pthread-stubs 0.4 hb9d3cd8_1002 conda-forge
py-cpuinfo 9.0.0 pypi_0 pypi
pyarrow 18.1.0 pypi_0 pypi
pycparser 2.22 pyhd8ed1ab_0 conda-forge
pydantic 2.10.2 pypi_0 pypi
pydantic-core 2.27.1 pypi_0 pypi
pygments 2.18.0 pypi_0 pypi
pyparsing 3.2.0 pypi_0 pypi
pysocks 1.7.1 pyha2e5f31_6 conda-forge
python 3.11.9 hb806964_0_cpython conda-forge
python-dateutil 2.9.0.post0 pypi_0 pypi
python-dotenv 1.0.1 pypi_0 pypi
python_abi 3.11 5_cp311 conda-forge
pytorch 2.5.1 py3.11_cuda12.1_cudnn9.1.0_0 pytorch
pytorch-cuda 12.1 ha16c6d3_6 pytorch
pytorch-mutex 1.0 cuda pytorch
pytz 2024.2 pypi_0 pypi
pyyaml 6.0.2 py311h9ecbd09_1 conda-forge
readline 8.2 h8228510_1 conda-forge
regex 2024.11.6 pypi_0 pypi
requests 2.32.3 pyhd8ed1ab_0 conda-forge
rich 13.9.4 pypi_0 pypi
s3transfer 0.10.4 pypi_0 pypi
safetensors 0.4.5 pypi_0 pypi
sentencepiece 0.2.0 pypi_0 pypi
setuptools 75.6.0 pyhff2d567_1 conda-forge
shtab 1.7.1 pypi_0 pypi
six 1.16.0 pypi_0 pypi
sympy 1.13.1 pypi_0 pypi
tbb 2021.13.0 hceb3a55_1 conda-forge
tensorboard 2.18.0 pypi_0 pypi
tensorboard-data-server 0.7.2 pypi_0 pypi
termcolor 2.5.0 pypi_0 pypi
timm 1.0.11 pypi_0 pypi
tk 8.6.13 noxft_h4845f30_101 conda-forge
tokenizers 0.20.3 pypi_0 pypi
torchaudio 2.5.1 py311_cu121 pytorch
torchtriton 3.1.0 py311 pytorch
torchvision 0.20.1 py311_cu121 pytorch
tqdm 4.67.1 pypi_0 pypi
transformers 4.45.1 pypi_0 pypi
trl 0.10.1 pypi_0 pypi
typeguard 4.4.1 pypi_0 pypi
typing_extensions 4.12.2 pyha770c72_0 conda-forge
tyro 0.9.2 pypi_0 pypi
tzdata 2024.2 pypi_0 pypi
urllib3 2.2.3 pyhd8ed1ab_0 conda-forge
werkzeug 3.1.3 pypi_0 pypi
wheel 0.45.1 pyhd8ed1ab_0 conda-forge
xorg-libxau 1.0.11 hb9d3cd8_1 conda-forge
xorg-libxdmcp 1.1.5 hb9d3cd8_0 conda-forge
xxhash 3.5.0 pypi_0 pypi
xz 5.2.6 h166bdaf_0 conda-forge
yaml 0.2.5 h7f98852_2 conda-forge
yapf 0.43.0 pypi_0 pypi
yarl 1.18.3 pypi_0 pypi
zlib 1.2.13 h4ab18f5_6 conda-forge
zstandard 0.23.0 py311hbc35293_1 conda-forge
zstd 1.5.6 ha6fb4c9_0 conda-forge
Error traceback
No response
Have you solved this?
Has this been solved?
For training, it is better not to enable TP (tensor parallelism). Use FSDP or DeepSpeed ZeRO-3 instead.
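Following that suggestion, the sketch below shows one way to move the command above from `zero_stage1_config.json` to a ZeRO-3 config. The file name `zero_stage3_config.json` and the specific options are assumptions. This config has not been confirmed to fix the shape mismatch. The `"auto"` values rely on the Hugging Face Trainer's DeepSpeed integration, which fills them in from the command-line arguments.

```python
import json

# Hypothetical ZeRO-3 config; "auto" entries are resolved by the
# Hugging Face Trainer's DeepSpeed integration from the CLI arguments.
zero3_config = {
    "zero_optimization": {
        "stage": 3,                                        # shard params, grads, optimizer states
        "overlap_comm": True,                              # overlap communication with compute
        "contiguous_gradients": True,                      # reduce memory fragmentation
        "stage3_gather_16bit_weights_on_model_save": True, # save a consolidated 16-bit model
    },
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
}

with open("zero_stage3_config.json", "w") as f:
    json.dump(zero3_config, f, indent=2)
```

The launch command would then pass `--deepspeed "zero_stage3_config.json"` instead of the stage-1 file.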