jax-gcn
Fix index error
Hey there!
Thanks for this repository (and the accompanying blog post); it's really helpful for learning more about JAX and graph neural networks!
When I run `python train.py`, I get the following error:

```
Starting training...
Traceback (most recent call last):
File "train.py", line 90, in <module>
opt_state = update(epoch, opt_state, batch)
File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 371, in f_jitted
return cpp_jitted_f(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 284, in cache_miss
donated_invars=donated_invars)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1229, in bind
return call_bind(self, fun, *args, **params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1220, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1232, in process
return trace.process_call(self, fun, tracers, params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 598, in process_call
return primitive.impl(f, *tracers, **params)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 570, in _xla_call_impl
*unsafe_map(arg_spec, args))
File "/usr/local/lib/python3.6/dist-packages/jax/linear_util.py", line 251, in memoized_fun
ans = call(fun, *args)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 645, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 1230, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/usr/local/lib/python3.6/dist-packages/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "train.py", line 81, in update
return opt_update(i, grad(loss)(params, batch), opt_state)
File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 706, in grad_f
_, g = value_and_grad_f(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 769, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args)
File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 1846, in _vjp
out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py", line 114, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py", line 101, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 516, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/usr/local/lib/python3.6/dist-packages/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 371, in f_jitted
return cpp_jitted_f(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 284, in cache_miss
donated_invars=donated_invars)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1229, in bind
return call_bind(self, fun, *args, **params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1220, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1232, in process
return trace.process_call(self, fun, tracers, params)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py", line 318, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1229, in bind
return call_bind(self, fun, *args, **params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1220, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1232, in process
return trace.process_call(self, fun, tracers, params)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 193, in process_call
f, in_pvals, app, instantiate=False)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 310, in partial_eval
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1229, in bind
return call_bind(self, fun, *args, **params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1220, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1232, in process
return trace.process_call(self, fun, tracers, params)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 1085, in process_call
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/usr/local/lib/python3.6/dist-packages/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "train.py", line 20, in loss
ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 525, in __getitem__
def __getitem__(self, idx): return self.aval._getitem(self, idx)
File "/usr/local/lib/python3.6/dist-packages/jax/_src/numpy/lax_numpy.py", line 4104, in _rewriting_take
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
File "/usr/local/lib/python3.6/dist-packages/jax/_src/numpy/lax_numpy.py", line 4160, in _split_index_for_jit
idx = _eliminate_deprecated_list_indexing(idx)
File "/usr/local/lib/python3.6/dist-packages/jax/_src/numpy/lax_numpy.py", line 4408, in _eliminate_deprecated_list_indexing
raise TypeError(msg)
TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[array(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.
```
The reason for this is that you're indexing the tensors in `train.py` using Python ranges and lists, which the JAX authors have deprecated (see https://github.com/google/jax/issues/4564). It can be fixed by using `jnp.array(idx)` instead of `idx` (this PR does this for `idx_train`, `idx_val` and `idx_test`).
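Concretely, the change looks something like this (a minimal sketch with illustrative index ranges, not the exact lines from `train.py`):

```python
import jax.numpy as jnp

# Before: plain Python ranges/lists. Indexing a traced array with these
# raises the TypeError above inside jit-compiled functions.
idx_train = range(140)  # illustrative split sizes

# After: wrap the indices in a JAX array once, at data-loading time.
idx_train = jnp.array(idx_train)
idx_val = jnp.array(range(200, 500))
idx_test = jnp.array(range(500, 1500))

# Now preds[idx_train] works inside the jitted `loss` and `update` functions.
```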
The reason I'm running this is that I'd like to implement the same thing as you, but using Flax, the high-level deep learning API on top of JAX. I have a notebook you can run (training a GCN on Cora): https://drive.google.com/file/d/1D-GwuZH19p19RjnxuDbw4GsrmM3bCyNp/view?usp=sharing
What's weird is that my initial loss is the same as yours (1.94), but after that the loss stays at 1.81 and doesn't change anymore. I'm using the same optimizer and learning rate. It would be great if you could take a look!
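For reference, a GCN layer in Flax's linen API looks roughly like this (a minimal sketch, not the exact code from my notebook; `adj` is assumed to be the dense, symmetrically normalised adjacency matrix):

```python
import jax.numpy as jnp
import flax.linen as nn

class GCNLayer(nn.Module):
    """One graph convolution: H' = A_hat @ H @ W (Kipf & Welling)."""
    features: int

    @nn.compact
    def __call__(self, x, adj):
        # Linear transform of the node features, then neighbourhood
        # aggregation with the normalised adjacency matrix A_hat.
        x = nn.Dense(self.features)(x)
        return adj @ x
```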
Thanks for noticing this! I haven't run the code since I implemented it, so I wasn't aware of the error.
Is the issue with the loss stuck at 1.81 for this PR after the fix, or for the notebook using Flax?
I'll take a look at the notebook; I've used Flax a bit, so maybe I can spot something weird.
Yeah, I meant my notebook. Your code works just fine with my PR:
```
Iter 199/200 (0.0013 s) train_loss: 0.4010, train_acc: 0.9857, val_loss: 0.9357, val_acc: 0.7880
Test set acc: 0.8130000233650208
```
But my notebook seems to have an issue.
Hi @gcucurull,
My GCN implementation in Flax is working. However, I've also added the Graph Attention Network, and there's probably a mistake in there; could you please have a look?
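For reference, the attention computation I'm aiming for looks roughly like this (a sketch in plain JAX rather than Flax; names and shapes are illustrative):

```python
import jax.numpy as jnp
from jax.nn import leaky_relu, softmax

def gat_layer(h, adj, W, a):
    # h: (N, F) node features, adj: (N, N) adjacency with self-loops,
    # W: (F, F_out) weight matrix, a: (2 * F_out,) attention vector.
    Wh = h @ W                                        # (N, F_out)
    F_out = Wh.shape[1]
    # e_ij = LeakyReLU(a^T [Wh_i || Wh_j]), computed without materialising
    # the concatenation by splitting a into its two halves.
    e = leaky_relu((Wh @ a[:F_out])[:, None] + (Wh @ a[F_out:])[None, :],
                   negative_slope=0.2)
    # Mask out non-edges *before* the softmax; forgetting this makes every
    # node attend to the whole graph and the model barely learns.
    e = jnp.where(adj > 0, e, -1e9)
    alpha = softmax(e, axis=1)                        # attention coefficients
    return alpha @ Wh                                 # aggregated features
```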
I've created a copy of the notebook which you can edit: https://colab.research.google.com/drive/1YbRUgrnWhbMV2pJJMtpIjL51vAp4XVfW?usp=sharing
Thank you!