easy-neural-ode
Error running latent_ode.py
I tried running the script on the PhysioNet data and got the following error after a few iterations. Could you comment on this, and also say a bit more about what the expected output is?
TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?
When I ran it again, it just hung here:
python latent_ode.py --reg r3 --lam 1e-2
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
~/.conda/envs/Neural_ODE/lib/python3.8/site-packages/jax/_src/random.py:511: FutureWarning: jax.random.shuffle is deprecated and will be removed in a future release. Use jax.random.permutation
warnings.warn(msg, FutureWarning)
conda environment:
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 1_gnu conda-forge
absl-py 0.13.0 <pip>
backcall 0.2.0 <pip>
ca-certificates 2021.5.30 ha878542_0 conda-forge
certifi 2021.5.30 py38h578d9bd_0 conda-forge
cycler 0.10.0 <pip>
Cython 0.29.19 <pip>
debugpy 1.3.0 <pip>
dm-haiku 0.0.5.dev0 <pip>
flatbuffers 2.0 <pip>
future 0.18.2 <pip>
ipykernel 6.0.0 <pip>
ipython 7.25.0 <pip>
ipython-genutils 0.2.0 <pip>
jax 0.2.17 <pip>
jaxlib 0.1.68 <pip>
jedi 0.18.0 <pip>
jmp 0.0.2 <pip>
joblib 0.15.1 <pip>
jupyter-client 6.1.12 <pip>
jupyter-core 4.7.1 <pip>
kiwisolver 1.2.0 <pip>
ld_impl_linux-64 2.36.1 hea4e1c9_0 conda-forge
libffi 3.3 h58526e2_2 conda-forge
libgcc-ng 9.3.0 h2828fa1_19 conda-forge
libgomp 9.3.0 h2828fa1_19 conda-forge
libstdcxx-ng 9.3.0 h6de172a_19 conda-forge
matplotlib 3.2.1 <pip>
matplotlib-inline 0.1.2 <pip>
ncurses 6.2 h58526e2_4 conda-forge
numpy 1.21.0 <pip>
openssl 1.1.1k h7f98852_0 conda-forge
opt-einsum 3.3.0 <pip>
parso 0.8.2 <pip>
pexpect 4.8.0 <pip>
phate 1.0.7 <pip>
pickleshare 0.7.5 <pip>
pip 21.1.3 pyhd8ed1ab_0 conda-forge
POT 0.7.0 <pip>
prompt-toolkit 3.0.19 <pip>
ptyprocess 0.7.0 <pip>
Pygments 2.9.0 <pip>
pyparsing 2.4.7 <pip>
python 3.8.10 h49503c6_1_cpython conda-forge
python-dateutil 2.8.1 <pip>
python_abi 3.8 2_cp38 conda-forge
pyzmq 22.1.0 <pip>
readline 8.1 h46c0cb4_0 conda-forge
s-gd2 1.8 <pip>
scikit-learn 0.23.1 <pip>
scipy 1.4.1 <pip>
setuptools 49.6.0 py38h578d9bd_3 conda-forge
six 1.15.0 <pip>
sklearn 0.0 <pip>
sqlite 3.36.0 h9cd32fc_0 conda-forge
tabulate 0.8.9 <pip>
threadpoolctl 2.1.0 <pip>
tk 8.6.10 h21135ba_1 conda-forge
torch 1.5.0 <pip>
torchdiffeq 0.0.1 <pip>
tornado 6.1 <pip>
traitlets 5.0.5 <pip>
wcwidth 0.2.5 <pip>
wheel 0.36.2 pyhd3deb0d_0 conda-forge
xz 5.2.5 h516909a_1 conda-forge
zlib 1.2.11 h516909a_1010 conda-forge
Bump for this. Exactly same error.
Hello! Sorry for the delayed reply. I'm having some trouble reproducing this error actually. I used the preprocessed data available in the release. My conda environment export is:
channels:
- defaults
dependencies:
- ca-certificates=2021.7.5=hecd8cb5_1
- certifi=2021.5.30=py38hecd8cb5_0
- libcxx=12.0.0=h2f01273_0
- libffi=3.3=hb1e8313_2
- ncurses=6.2=h0a44026_1
- openssl=1.1.1l=h9ed2024_0
- python=3.8.11=h88f2d9e_1
- readline=8.1=h9ed2024_0
- setuptools=58.0.4=py38hecd8cb5_0
- sqlite=3.36.0=hce871da_0
- tk=8.6.10=hb0a8c7a_0
- wheel=0.37.0=pyhd3eb1b0_1
- xz=5.2.5=h1de35cc_0
- zlib=1.2.11=h1de35cc_3
- pip:
- absl-py==0.14.0
- dm-haiku==0.0.5.dev0
- flatbuffers==2.0
- jax==0.2.20
- jaxlib==0.1.71
- jmp==0.0.2
- numpy==1.21.2
- opt-einsum==3.3.0
- pip==21.2.4
- scipy==1.7.1
- six==1.16.0
- tabulate==0.8.9
I ran the command python latent_ode.py --reg r2 --lam 1e-2 --test_freq 1 on my laptop. So far, after ~10 minutes of running on my MacBook, I have:
Iter 0001 | Loss 798.138111 | Likelihood -808.377092 | KL 2.490536 | MSE 0.165348 | Enc. r 0.000000 | Dec. r 0.001278 | Enc. NFE 0.000000 | Dec. NFE 31.824688
Iter 0002 | Loss 551.387941 | Likelihood -566.549929 | KL 1.965105 | MSE 0.116983 | Enc. r 0.000000 | Dec. r 0.005880 | Enc. NFE 0.000000 | Dec. NFE 31.839688
Iter 0003 | Loss 495.621342 | Likelihood -497.331870 | KL 1.669389 | MSE 0.103139 | Enc. r 0.000000 | Dec. r 0.020152 | Enc. NFE 0.000000 | Dec. NFE 34.642188
Iter 0004 | Loss 332.830424 | Likelihood -335.500099 | KL 1.934213 | MSE 0.070773 | Enc. r 0.000000 | Dec. r 0.016797 | Enc. NFE 0.000000 | Dec. NFE 32.999062
Iter 0005 | Loss 222.494621 | Likelihood -230.846079 | KL 2.180931 | MSE 0.049842 | Enc. r 0.000000 | Dec. r 0.010237 | Enc. NFE 0.000000 | Dec. NFE 35.735312
In particular, I used r2 since it uses less memory. Using r3 is possible, but I typically only ran this on a remote cluster where I had access to more RAM.
The first time you ran it, you got the error TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead? Was this after you ran the data processing code yourself?
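For context on what that TypeError means in general, here is a minimal sketch, not taken from latent_ode.py (I don't know which line in the script triggers it). It assumes a JAX release that provides the .at[...].set(...) indexed-update syntax; older releases exposed the same idea through jax.ops.index_update, as the error message suggests. The array x is purely illustrative.

```python
import jax.numpy as jnp

# Illustrative array, not data from the script.
x = jnp.zeros(5)

# In-place item assignment is what raises the TypeError on a JAX array.
try:
    x[0] = 1.0
    raised = False
except TypeError:
    raised = True

# Functional update: returns a new array and leaves x unchanged.
y = x.at[0].set(1.0)

print(raised, y)
```

If the error comes from code that fills an array in place, one option is to build it as a NumPy array first and convert with jnp.asarray afterwards.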
In summary, my suggestions are:
- See if my conda environment is different from yours and if matching it fixes the error.
- Set --test_freq 1 to confirm the code is running (the default is --test_freq 640).
- Try --reg r2 since it's faster and uses less memory.
- Try running on a remote machine with more RAM, especially if you want to use --reg r3, e.g. try Google Colab?
Please let me know if any of this is helpful, or if you have any other issues!
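To make the environment comparison in the first suggestion concrete, here is a rough sketch in plain Python. The helper names (parse_conda_list, parse_export_pip, diff_versions) are hypothetical and only handle the two listing formats pasted in this thread; the sample entries are copied from those listings.

```python
def parse_conda_list(text):
    """Parse "name version ..." lines (conda list style) into {name: version}."""
    versions = {}
    for line in text.splitlines():
        parts = line.split()
        if len(parts) >= 2 and not line.startswith("#"):
            versions[parts[0].lower()] = parts[1]
    return versions


def parse_export_pip(text):
    """Parse "- name==version" lines (pip section of an env export)."""
    versions = {}
    for line in text.splitlines():
        entry = line.strip().lstrip("-").strip()
        if "==" in entry:
            name, version = entry.split("==", 1)
            versions[name.lower()] = version
    return versions


def diff_versions(mine, theirs):
    """Return {name: (my_version, their_version)} for shared packages that differ."""
    shared = sorted(mine.keys() & theirs.keys())
    return {n: (mine[n], theirs[n]) for n in shared if mine[n] != theirs[n]}


# A few entries copied from the two environments posted above.
reporter = parse_conda_list("""jax 0.2.17 <pip>
jaxlib 0.1.68 <pip>
numpy 1.21.0 <pip>
scipy 1.4.1 <pip>
dm-haiku 0.0.5.dev0 <pip>""")
maintainer = parse_export_pip("""- jax==0.2.20
- jaxlib==0.1.71
- numpy==1.21.2
- scipy==1.7.1
- dm-haiku==0.0.5.dev0""")

mismatches = diff_versions(reporter, maintainer)
for name, (yours, mine) in mismatches.items():
    print(f"{name}: {yours} (yours) vs {mine} (mine)")
```

On these entries the jax, jaxlib, numpy, and scipy versions differ, so those are the first packages I would try aligning.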