motion-diffusion-model

Training MDM

Open: rainofmine opened this issue 2 years ago (20 comments)

I tried to train MDM on HumanML3D with the provided training script, but the loss is NaN and the predicted results are incorrect. Is something wrong?

By the way, an error occurs when I run training with --eval_during_training or --train_platform_type {ClearmlPlatform, TensorboardPlatform}.

rainofmine, Oct 17 '22

Hi @rainofmine,

  1. Can you please provide the exact command line that you have used?
  2. Can you please provide the standard output of the training? Also, note that the evaluation triggered by --eval_during_training takes ~90 minutes, during which no loss is printed. However, the loss at iteration 0 should be printed before the evaluation starts, and it should not be NaN.

sigal-raab, Oct 18 '22
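
You can catch a NaN at iteration 0 as soon as it appears by checking the loss value on every step. The sketch below is a hypothetical helper and is not part of the MDM training script. It assumes the loss has already been converted to a Python float (in PyTorch, for example, with loss.item()):

```python
import math


def check_loss_finite(loss_value: float, step: int) -> float:
    """Raise immediately if the training loss is NaN or infinite.

    Hypothetical helper: call it right after computing the loss so a
    non-finite value is reported at the step where it first appears.
    """
    if not math.isfinite(loss_value):
        raise FloatingPointError(
            f"Non-finite loss {loss_value!r} at step {step}; "
            "check the input data and normalization statistics."
        )
    return loss_value


if __name__ == "__main__":
    print(check_loss_finite(0.73, step=0))
    try:
        check_loss_finite(float("nan"), step=0)
    except FloatingPointError as err:
        print(err)
```

Raising at the first non-finite value points to the offending step instead of letting NaN propagate silently through later iterations. PyTorch's torch.autograd.set_detect_anomaly(True) is a heavier alternative: it can also locate NaNs produced during the backward pass, but at a noticeable speed cost.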

@sigal-raab, I just ran the provided script:

python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset humanml

The training log: [screenshot]

rainofmine, Oct 18 '22

@rainofmine, thank you for the information. According to the command line above, the NaN loss also occurs without --eval_during_training or --train_platform_type {ClearmlPlatform, TensorboardPlatform}. Is that correct? We cannot reproduce the problem in our environment, so I'd like to find out what differs between our setups:

  1. What operating system are you using and which version?
  2. What cuda version are you using?
  3. Please send us the details of the conda environment in which you run the training: run "conda env export > environment.yml" and attach the resulting environment.yml file to your reply.

sigal-raab, Oct 18 '22
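
A short script can also gather the requested details. The sketch below uses only the standard library, plus PyTorch if it is installed. The function name is hypothetical. PyTorch also ships its own, more complete report, which you can run with python -m torch.utils.collect_env.

```python
import platform
import sys


def collect_env_info() -> dict:
    """Gather OS, Python, PyTorch and CUDA details for a bug report."""
    info = {
        "os": platform.platform(),
        "python": sys.version.split()[0],
        "torch": None,
        "torch_cuda": None,
        "cudnn": None,
        "cuda_available": None,
    }
    try:
        import torch  # optional: PyTorch fields stay None if it is not installed
    except ImportError:
        return info
    info["torch"] = torch.__version__
    info["torch_cuda"] = torch.version.cuda  # CUDA version PyTorch was built with
    info["cudnn"] = torch.backends.cudnn.version()  # None if cuDNN is unavailable
    info["cuda_available"] = torch.cuda.is_available()
    return info


if __name__ == "__main__":
    for key, value in collect_env_info().items():
        print(f"{key}: {value}")
```

Note that torch.version.cuda is the CUDA version PyTorch was built against, and it need not match the system CUDA toolkit. For instance, the environment posted below lists cudatoolkit=11.0.221 and a cuda11.0 PyTorch build alongside a reported system CUDA of 10.1.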

@sigal-raab

  • Ubuntu 18.04
  • cuda 10.1
  • Pytorch 1.7.1

The environment:

name: mdm

channels:

  • pytorch
  • conda-forge
  • defaults

dependencies:
  • _libgcc_mutex=0.1=main
  • _openmp_mutex=5.1=1_gnu
  • beautifulsoup4=4.11.1=pyha770c72_0
  • blas=1.0=mkl
  • brotlipy=0.7.0=py37h540881e_1004
  • ca-certificates=2022.9.24=ha878542_0
  • catalogue=2.0.8=py37h89c1867_0
  • certifi=2022.9.24=pyhd8ed1ab_0
  • cffi=1.15.1=py37h74dc2b5_0
  • charset-normalizer=2.1.1=pyhd8ed1ab_0
  • colorama=0.4.5=pyhd8ed1ab_0
  • cryptography=35.0.0=py37hf1a17b8_2
  • cudatoolkit=11.0.221=h6bb024c_0
  • cycler=0.11.0=pyhd3eb1b0_0
  • cymem=2.0.6=py37hd23a5d3_3
  • dataclasses=0.8=pyhc8e2a94_3
  • dbus=1.13.18=hb2f20db_0
  • expat=2.4.9=h6a678d5_0
  • fftw=3.3.9=h27cfd23_1
  • filelock=3.8.0=pyhd8ed1ab_0
  • fontconfig=2.13.1=h6c09931_0
  • freetype=2.11.0=h70c0345_0
  • gdown=4.5.1=pyhd8ed1ab_0
  • giflib=5.2.1=h7b6447c_0
  • glib=2.69.1=h4ff587b_1
  • gst-plugins-base=1.14.0=h8213a91_2
  • gstreamer=1.14.0=h28cd5cc_2
  • h5py=3.7.0=py37h737f45e_0
  • hdf5=1.10.6=h3ffc7dd_1
  • icu=58.2=he6710b0_3
  • idna=3.4=pyhd8ed1ab_0
  • intel-openmp=2021.4.0=h06a4308_3561
  • jinja2=3.1.2=pyhd8ed1ab_1
  • jpeg=9b=h024ee3a_2
  • kiwisolver=1.4.2=py37h295c915_0
  • langcodes=3.3.0=pyhd8ed1ab_0
  • lcms2=2.12=h3be6417_0
  • ld_impl_linux-64=2.38=h1181459_1
  • libffi=3.3=he6710b0_2
  • libgcc-ng=11.2.0=h1234567_1
  • libgfortran-ng=11.2.0=h00389a5_1
  • libgfortran5=11.2.0=h1234567_1
  • libgomp=11.2.0=h1234567_1
  • libpng=1.6.37=hbc83047_0
  • libstdcxx-ng=11.2.0=h1234567_1
  • libtiff=4.1.0=h2733197_1
  • libuuid=1.0.3=h7f8727e_2
  • libuv=1.40.0=h7b6447c_0
  • libwebp=1.2.0=h89dd481_0
  • libxcb=1.15=h7f8727e_0
  • libxml2=2.9.14=h74e7548_0
  • lz4-c=1.9.3=h295c915_1
  • markupsafe=2.1.1=py37h540881e_1
  • matplotlib=3.1.3=py37_0
  • matplotlib-base=3.1.3=py37hef1b27d_0
  • mkl=2021.4.0=h06a4308_640
  • mkl-service=2.4.0=py37h7f8727e_0
  • mkl_fft=1.3.1=py37hd3c417c_0
  • mkl_random=1.2.2=py37h51133e4_0
  • ncurses=6.3=h5eee18b_3
  • ninja=1.10.2=h06a4308_5
  • ninja-base=1.10.2=hd09550d_5
  • numpy=1.21.5=py37h6c91a56_3
  • numpy-base=1.21.5=py37ha15fc14_3
  • openssl=1.1.1q=h7f8727e_0
  • packaging=21.3=pyhd8ed1ab_0
  • pathy=0.6.2=pyhd8ed1ab_0
  • pcre=8.45=h295c915_0
  • pillow=9.2.0=py37hace64e9_1
  • pip=22.2.2=py37h06a4308_0
  • pycparser=2.21=pyhd8ed1ab_0
  • pydantic=1.8.2=py37h5e8e339_2
  • pyopenssl=22.0.0=pyhd8ed1ab_1
  • pyparsing=3.0.9=py37h06a4308_0
  • pyqt=5.9.2=py37h05f1152_2
  • pysocks=1.7.1=py37h89c1867_5
  • python=3.7.13=h12debd9_0
  • python-dateutil=2.8.2=pyhd3eb1b0_0
  • python_abi=3.7=2_cp37m
  • pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0
  • qt=5.9.7=h5867ecd_1
  • readline=8.1.2=h7f8727e_1
  • requests=2.28.1=pyhd8ed1ab_1
  • scipy=1.7.3=py37h6c91a56_2
  • setuptools=63.4.1=py37h06a4308_0
  • shellingham=1.5.0=pyhd8ed1ab_0
  • sip=4.19.8=py37hf484d3e_0
  • six=1.16.0=pyhd3eb1b0_1
  • smart_open=5.2.1=pyhd8ed1ab_0
  • soupsieve=2.3.2.post1=pyhd8ed1ab_0
  • spacy=3.3.1=py37h79cecc1_0
  • spacy-legacy=3.0.10=pyhd8ed1ab_0
  • spacy-loggers=1.0.3=pyhd8ed1ab_0
  • sqlite=3.39.3=h5082296_0
  • tk=8.6.12=h1ccaba5_0
  • torchaudio=0.7.2=py37
  • torchvision=0.8.2=py37_cu110
  • tornado=6.2=py37h5eee18b_0
  • tqdm=4.64.1=py37h06a4308_0
  • trimesh=3.15.3=pyh1a96a4e_0
  • typer=0.4.2=pyhd8ed1ab_0
  • wheel=0.37.1=pyhd3eb1b0_0
  • xz=5.2.6=h5eee18b_0
  • zipp=3.8.1=pyhd8ed1ab_0
  • zlib=1.2.12=h5eee18b_3
  • zstd=1.4.9=haebb681_0
  • pip:
    • attrs==22.1.0
    • blis==0.7.8
    • blobfile==2.0.0
    • chumpy==0.70
    • clearml==1.7.1
    • click==8.1.3
    • clip==1.0
    • confection==0.0.2
    • en-core-web-sm==3.3.0
    • ftfy==6.1.1
    • furl==2.1.3
    • future==0.18.2
    • importlib-metadata==5.0.0
    • importlib-resources==5.10.0
    • jsonschema==4.16.0
    • lxml==4.9.1
    • murmurhash==1.0.8
    • orderedmultidict==1.0.1
    • pathlib2==2.3.7.post1
    • pkgutil-resolve-name==1.3.10
    • preshed==3.0.7
    • psutil==5.9.2
    • pycryptodomex==3.15.0
    • pyjwt==2.4.0
    • pyrsistent==0.18.1
    • pyyaml==6.0
    • regex==2022.9.13
    • smplx==0.1.28
    • srsly==2.4.4
    • thinc==8.0.17
    • typing-extensions==4.1.1
    • urllib3==1.26.12
    • wasabi==0.10.1
    • wcwidth==0.2.5

rainofmine, Oct 19 '22

@rainofmine, here are my environment details:

  • Ubuntu 18.04
  • Cuda 11.1
  • Pytorch 1.12.1+cu102
  • packages as in the environment.yml file in this repo.

Comparing with your environment, I suspect the biggest difference is the CUDA version; 10.1 is old. Can you upgrade to CUDA 11.1 (or newer) and let me know whether that solves the problem? If it does not, try changing the PyTorch version. Also, please tell me which Python version you are using; our code has been tested on Python 3.7. After that, the next thing to check is the pip package versions, which differ slightly from those in our environment.yml file.

sigal-raab, Oct 19 '22
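
You can automate a comparison like the one above. The minimal sketch below assumes each environment is available as a list of dependency strings. Examples are the output of conda list --export (name=version=build) or pip-style name==version lines. Parsing environment.yml itself would need a YAML library, which this sketch avoids. The function names are hypothetical.

```python
def parse_specs(lines):
    """Map package name -> version from 'name=version=build' or 'name==version' strings."""
    specs = {}
    for line in lines:
        line = line.strip().lstrip("-").strip()
        if not line or line.endswith(":"):
            continue  # skip blank lines and section headers such as 'pip:'
        if "==" in line:
            name, version = line.split("==", 1)  # pip style
        else:
            parts = line.split("=")  # conda style; the build string is ignored
            name = parts[0]
            version = parts[1] if len(parts) > 1 else ""
        specs[name.strip().lower()] = version.strip()
    return specs


def diff_specs(ours, theirs):
    """Return {package: (our_version, their_version)} for packages that differ or are missing."""
    a, b = parse_specs(ours), parse_specs(theirs)
    diffs = {}
    for name in sorted(set(a) | set(b)):
        if a.get(name) != b.get(name):
            diffs[name] = (a.get(name), b.get(name))
    return diffs


if __name__ == "__main__":
    # Hypothetical example inputs, not the maintainers' actual environment.
    env_a = ["python=3.7.13=h12debd9_0", "numpy=1.21.5=py37h6c91a56_3", "smplx==0.1.28"]
    env_b = ["python=3.7.13=h12debd9_0", "numpy=1.21.2=py37h20f2e39_0"]
    for name, (a, b) in diff_specs(env_a, env_b).items():
        print(f"{name}: {a} vs {b}")
```

The sketch ignores build strings on purpose, so it reports only version differences and packages present on one side but not the other.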

@rainofmine, I noticed that you closed the issue. Is that because it was solved? If so, what was the solution? I am asking because others may encounter the same problem, and your answer could help them.

sigal-raab, Oct 20 '22

Hi @rainofmine, I am encountering the same problem. Did you find a solution? Thanks!

Ying156209, Nov 22 '22

Hi @rainofmine, I have the same question: how did you solve it? Thanks!

Kai-0515, Feb 19 '23

Hi @rainofmine, how did you solve this problem? Looking forward to your reply, thanks!

Kai-0515, Feb 20 '23

@Kai-0515, my advice to @rainofmine was to upgrade CUDA to 11.1 or higher. He did not reply, so I don't know whether it worked for him; however, he closed the issue, which may indicate that it was solved. I have reopened the issue because of your question. If your CUDA version is relatively old, could you try installing a newer one and report the results?

sigal-raab, Feb 20 '23

Thanks for your reply. My CUDA version is 11.4, and my other settings are the same as the ones you provided. I also tried reducing the batch size to 8, but the loss is still NaN. I processed the HumanML3D dataset following its instructions and evaluated it, so the dataset should be fine. Are there any other settings in the code that might affect the results?

Kai-0515, Feb 20 '23

@Kai-0515, do you encounter the same problem when working with humanact12 or uestc? Even if those are not the datasets you want to work with, your answer may help us figure out the cause of this issue.

sigal-raab, Feb 20 '23

I will give it a try.

Kai-0515, Feb 20 '23

I found the problem: some of the HumanML3D data is broken, which I discovered when evaluating it with the method HumanML3D provides. The broken data lies in the 3000-5000 range; after cleaning it, I get correct results.

Kai-0515, Feb 21 '23
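
The thread does not say how the broken samples were damaged, and it was HumanML3D's own evaluation that revealed the problem here. The minimal sketch below assumes breakage shows up as files that cannot be loaded or that contain NaN/Inf values. It scans a directory of HumanML3D-style .npy motion files and lists the suspicious ones. The function name is hypothetical.

```python
from pathlib import Path

import numpy as np


def find_broken_npy(data_dir):
    """Return sorted names of .npy files that fail to load or contain NaN/Inf."""
    broken = []
    for path in sorted(Path(data_dir).glob("*.npy")):
        try:
            arr = np.load(path)  # allow_pickle=False by default
        except (OSError, ValueError, EOFError):
            broken.append(path.name)  # unreadable, truncated or non-array file
            continue
        if not np.isfinite(arr).all():
            broken.append(path.name)  # contains NaN or +/-Inf
    return broken


if __name__ == "__main__":
    import sys

    print(find_broken_npy(sys.argv[1] if len(sys.argv) > 1 else "."))
```

For example, find_broken_npy("dataset/HumanML3D/new_joint_vecs") would return the file names to inspect. A file can pass this check and still be wrong in other ways (for example, implausible values), so the check complements the dataset's own tools rather than replacing them.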

@Kai-0515, I am glad it works for you now. @GuyTevet, can you double-check the data?

sigal-raab, Feb 21 '23

@Kai-0515 can you please open an issue in https://github.com/EricGuo5513/HumanML3D ? It is possible that there is a bug in the data pre-processing.

GuyTevet, Feb 22 '23

@Kai-0515 Hi, I encountered the same problem. How did you find the broken data?

ShungJhon, Mar 03 '23

Using cal_mean_variance.ipynb provided by HumanML3D, I found that the broken files in my copy are 004355.npy and M004355.npy. After removing them, the loss is computed normally.

ShungJhon, Mar 03 '23
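
cal_mean_variance.ipynb computes per-feature mean and standard deviation statistics over the whole dataset. Assuming these statistics are then used to normalize every sample, one corrupted file is enough to produce a NaN loss even on clean samples. This is consistent with reducing the batch size not helping above. The toy sketch below uses random stand-in data and hypothetical function names, not the notebook's actual code:

```python
import numpy as np


def mean_std(motions):
    """Per-feature mean and std over all frames of all motions (each array is frames x features)."""
    data = np.concatenate(motions, axis=0)
    return data.mean(axis=0), data.std(axis=0)


def normalize(motion, mean, std, eps=1e-8):
    """Z-normalize one motion with dataset statistics; eps is an assumed guard against zero std."""
    return (motion - mean) / (std + eps)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    clean = [rng.normal(size=(10, 4)) for _ in range(3)]
    broken = rng.normal(size=(10, 4))
    broken[0, 2] = np.nan  # one corrupted value in one "file"

    mean, std = mean_std(clean + [broken])
    print(np.isnan(mean[2]))  # True: feature 2's statistics are NaN
    toy_loss = np.mean(normalize(clean[0], mean, std) ** 2)
    print(np.isnan(toy_loss))  # True: even a clean sample now yields a NaN loss

    # Dropping files that contain non-finite values restores usable statistics.
    kept = [m for m in clean + [broken] if np.isfinite(m).all()]
    mean, std = mean_std(kept)
    print(np.isfinite(normalize(clean[0], mean, std)).all())  # True
```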

@ShungJhon, do you mean removing the broken data from the train_set or val_set list file?

qiqiApink, Mar 08 '23

@qiqiApink, from the new_joints and new_joint_vecs directories in dataset/HumanML3D/.

ShungJhon, Mar 09 '23
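
ShungJhon's fix removes each broken sample together with its mirrored counterpart (the M-prefixed file) from both directories. The sketch below is a hypothetical helper that moves these files into a separate folder instead of deleting them, so the change can be undone. The quarantine folder name "broken" is an assumption. Whether the split list files also need editing depends on how the data loader treats missing files, which this thread does not settle.

```python
import shutil
from pathlib import Path


def quarantine_samples(data_root, sample_ids, subdirs, quarantine="broken"):
    """Move <id>.npy and its mirrored M<id>.npy from each subdir into data_root/quarantine/subdir.

    Returns the moved files as 'subdir/name' strings, in the order they were moved.
    """
    root = Path(data_root)
    moved = []
    for subdir in subdirs:
        target_dir = root / quarantine / subdir
        target_dir.mkdir(parents=True, exist_ok=True)
        for sample_id in sample_ids:
            for name in (f"{sample_id}.npy", f"M{sample_id}.npy"):
                src = root / subdir / name
                if src.exists():  # a sample may be missing from one of the directories
                    shutil.move(str(src), str(target_dir / name))
                    moved.append(f"{subdir}/{name}")
    return moved
```

For the case reported here, the call would be quarantine_samples("dataset/HumanML3D", ["004355"], ["new_joints", "new_joint_vecs"]). To restore the files, move them back from dataset/HumanML3D/broken/.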