easy-neural-ode

Error running latent_ode.py

Open tkamthroche opened this issue 2 years ago • 2 comments

Tried running the script on the PhysioNet data and got the following error after a few iterations. Can you comment on this, and also say a bit more about what the expected output is?

TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?
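The error means some code assigns into a JAX array in place, NumPy style (`x[i] = v`). The snippet below is a minimal illustration, not code from latent_ode.py (the thread does not say which line triggers it): it reproduces the TypeError and shows the functional alternative. The `jax.ops.index_update` / `jax.ops.index_add` helpers suggested by the message were later deprecated and removed from JAX; `x.at[idx].set(...)` and `x.at[idx].add(...)` are the equivalents in current releases.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# NumPy-style in-place assignment raises TypeError because JAX arrays are immutable.
try:
    x[0] = 1.0
    assignment_failed = False
except TypeError:
    assignment_failed = True

# Functional updates return a new array and leave the original untouched.
y = x.at[0].set(1.0)      # replaces jax.ops.index_update(x, 0, 1.0)
z = y.at[1:3].add(2.0)    # replaces jax.ops.index_add(y, jax.ops.index[1:3], 2.0)
```

Because `x` is never modified, code that relied on mutation has to rebind the name (`x = x.at[i].set(v)`).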

Upon running it again, it would just hang here:

 python latent_ode.py --reg r3 --lam 1e-2
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
~/.conda/envs/Neural_ODE/lib/python3.8/site-packages/jax/_src/random.py:511: FutureWarning: jax.random.shuffle is deprecated and will be removed in a future release. Use jax.random.permutation
  warnings.warn(msg, FutureWarning)
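The FutureWarning above is only a deprecation notice; a warning does not stop execution, so on its own it does not explain the hang. To silence it, `jax.random.permutation` can produce the shuffled ordering instead. This is a minimal sketch assuming the script shuffles example indices; how latent_ode.py actually calls `jax.random.shuffle` is not shown in this thread, and `num_examples` and `data` are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
num_examples = 8                       # hypothetical dataset size

# permutation(key, n) with an integer n returns a shuffled arange(n).
perm = jax.random.permutation(key, num_examples)

data = jnp.arange(num_examples) * 10   # hypothetical stand-in for the training examples
shuffled = data[perm]                  # same elements, new order
```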

conda environment:

_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       1_gnu    conda-forge
absl-py                   0.13.0                    <pip>
backcall                  0.2.0                     <pip>
ca-certificates           2021.5.30            ha878542_0    conda-forge
certifi                   2021.5.30        py38h578d9bd_0    conda-forge
cycler                    0.10.0                    <pip>
Cython                    0.29.19                   <pip>
debugpy                   1.3.0                     <pip>
dm-haiku                  0.0.5.dev0                <pip>
flatbuffers               2.0                       <pip>
future                    0.18.2                    <pip>
ipykernel                 6.0.0                     <pip>
ipython                   7.25.0                    <pip>
ipython-genutils          0.2.0                     <pip>
jax                       0.2.17                    <pip>
jaxlib                    0.1.68                    <pip>
jedi                      0.18.0                    <pip>
jmp                       0.0.2                     <pip>
joblib                    0.15.1                    <pip>
jupyter-client            6.1.12                    <pip>
jupyter-core              4.7.1                     <pip>
kiwisolver                1.2.0                     <pip>
ld_impl_linux-64          2.36.1               hea4e1c9_0    conda-forge
libffi                    3.3                  h58526e2_2    conda-forge
libgcc-ng                 9.3.0               h2828fa1_19    conda-forge
libgomp                   9.3.0               h2828fa1_19    conda-forge
libstdcxx-ng              9.3.0               h6de172a_19    conda-forge
matplotlib                3.2.1                     <pip>
matplotlib-inline         0.1.2                     <pip>
ncurses                   6.2                  h58526e2_4    conda-forge
numpy                     1.21.0                    <pip>
openssl                   1.1.1k               h7f98852_0    conda-forge
opt-einsum                3.3.0                     <pip>
parso                     0.8.2                     <pip>
pexpect                   4.8.0                     <pip>
phate                     1.0.7                     <pip>
pickleshare               0.7.5                     <pip>
pip                       21.1.3             pyhd8ed1ab_0    conda-forge
POT                       0.7.0                     <pip>
prompt-toolkit            3.0.19                    <pip>
ptyprocess                0.7.0                     <pip>
Pygments                  2.9.0                     <pip>
pyparsing                 2.4.7                     <pip>
python                    3.8.10          h49503c6_1_cpython    conda-forge
python-dateutil           2.8.1                     <pip>
python_abi                3.8                      2_cp38    conda-forge
pyzmq                     22.1.0                    <pip>
readline                  8.1                  h46c0cb4_0    conda-forge
s-gd2                     1.8                       <pip>
scikit-learn              0.23.1                    <pip>
scipy                     1.4.1                     <pip>
setuptools                49.6.0           py38h578d9bd_3    conda-forge
six                       1.15.0                    <pip>
sklearn                   0.0                       <pip>
sqlite                    3.36.0               h9cd32fc_0    conda-forge
tabulate                  0.8.9                     <pip>
threadpoolctl             2.1.0                     <pip>
tk                        8.6.10               h21135ba_1    conda-forge
torch                     1.5.0                     <pip>
torchdiffeq               0.0.1                     <pip>
tornado                   6.1                       <pip>
traitlets                 5.0.5                     <pip>
wcwidth                   0.2.5                     <pip>
wheel                     0.36.2             pyhd3deb0d_0    conda-forge
xz                        5.2.5                h516909a_1    conda-forge
zlib                      1.2.11            h516909a_1010    conda-forge

tkamthroche, Jul 15 '21 09:07

Bump for this. Exactly the same error.

itamblyn, Aug 20 '21 18:08

Hello! Sorry for the delayed reply. I'm actually having some trouble reproducing this error. I used the preprocessed data available in the release. My conda environment export is:

channels:
  - defaults
dependencies:
  - ca-certificates=2021.7.5=hecd8cb5_1
  - certifi=2021.5.30=py38hecd8cb5_0
  - libcxx=12.0.0=h2f01273_0
  - libffi=3.3=hb1e8313_2
  - ncurses=6.2=h0a44026_1
  - openssl=1.1.1l=h9ed2024_0
  - python=3.8.11=h88f2d9e_1
  - readline=8.1=h9ed2024_0
  - setuptools=58.0.4=py38hecd8cb5_0
  - sqlite=3.36.0=hce871da_0
  - tk=8.6.10=hb0a8c7a_0
  - wheel=0.37.0=pyhd3eb1b0_1
  - xz=5.2.5=h1de35cc_0
  - zlib=1.2.11=h1de35cc_3
  - pip:
    - absl-py==0.14.0
    - dm-haiku==0.0.5.dev0
    - flatbuffers==2.0
    - jax==0.2.20
    - jaxlib==0.1.71
    - jmp==0.0.2
    - numpy==1.21.2
    - opt-einsum==3.3.0
    - pip==21.2.4
    - scipy==1.7.1
    - six==1.16.0
    - tabulate==0.8.9

I ran the command python latent_ode.py --reg r2 --lam 1e-2 --test_freq 1 on my laptop, and after ~10 minutes of running on my MacBook I have so far:

Iter 0001 | Loss 798.138111 | Likelihood -808.377092 | KL 2.490536 | MSE 0.165348 | Enc. r 0.000000 | Dec. r 0.001278 | Enc. NFE 0.000000 | Dec. NFE 31.824688
Iter 0002 | Loss 551.387941 | Likelihood -566.549929 | KL 1.965105 | MSE 0.116983 | Enc. r 0.000000 | Dec. r 0.005880 | Enc. NFE 0.000000 | Dec. NFE 31.839688
Iter 0003 | Loss 495.621342 | Likelihood -497.331870 | KL 1.669389 | MSE 0.103139 | Enc. r 0.000000 | Dec. r 0.020152 | Enc. NFE 0.000000 | Dec. NFE 34.642188
Iter 0004 | Loss 332.830424 | Likelihood -335.500099 | KL 1.934213 | MSE 0.070773 | Enc. r 0.000000 | Dec. r 0.016797 | Enc. NFE 0.000000 | Dec. NFE 32.999062
Iter 0005 | Loss 222.494621 | Likelihood -230.846079 | KL 2.180931 | MSE 0.049842 | Enc. r 0.000000 | Dec. r 0.010237 | Enc. NFE 0.000000 | Dec. NFE 35.735312

I used r2 in particular because it uses less memory. Using r3 is possible, but I typically only ran it on a remote cluster where I had access to more RAM.

When you first ran it and got the error TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?, was that after running the data processing code yourself?

In summary, my suggestions are:

  1. Check whether my conda environment differs from yours, and whether matching it fixes this error.
  2. Set --test_freq 1 to confirm the code is running (the default is --test_freq 640).
  3. Try --reg r2, since it's faster and uses less memory.
  4. Try running on a remote machine with more RAM, especially if you want to use --reg r3, e.g. Google Colab.
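Suggestion 2 matters because a run that only logs every --test_freq iterations prints nothing for a long time with the default of 640, which can look like a hang. The sketch below is hypothetical: it reuses the flag names from this thread but is not latent_ode.py's actual parser, and `logged_iterations` is an illustrative helper assuming the loop logs when the 1-based iteration count is a multiple of test_freq (consistent with the "Iter 0001" line above).

```python
import argparse

# Hypothetical parser mirroring the flags mentioned in the thread; only the
# --test_freq default (640) comes from the thread itself.
parser = argparse.ArgumentParser()
parser.add_argument("--reg", type=str)
parser.add_argument("--lam", type=float)
parser.add_argument("--test_freq", type=int, default=640)


def logged_iterations(num_iters, test_freq):
    """1-based iterations at which a loop logging every test_freq steps would print."""
    return [itr for itr in range(1, num_iters + 1) if itr % test_freq == 0]


args = parser.parse_args(["--reg", "r2", "--lam", "1e-2", "--test_freq", "1"])
default_args = parser.parse_args(["--reg", "r2", "--lam", "1e-2"])
```

With --test_freq 1 every iteration prints, while with the default the first line only appears after iteration 640.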

Please let me know if any of this is helpful, or if you have any other issues!

jacobjinkelly, Sep 23 '21 17:09