jax-gcn

Fix index error

Open NielsRogge opened this issue 3 years ago • 3 comments

Hey there!

Thanks for this repository (and the accompanying blog post), really helpful to learn more about JAX and graph neural networks!

When I run python train.py, I get the following error:

Starting training...
Traceback (most recent call last):
  File "train.py", line 90, in <module>
    opt_state = update(epoch, opt_state, batch)
  File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 371, in f_jitted
    return cpp_jitted_f(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 284, in cache_miss
    donated_invars=donated_invars)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1229, in bind
    return call_bind(self, fun, *args, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1220, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1232, in process
    return trace.process_call(self, fun, tracers, params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 598, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 570, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/usr/local/lib/python3.6/dist-packages/jax/linear_util.py", line 251, in memoized_fun
    ans = call(fun, *args)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 645, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 1230, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/usr/local/lib/python3.6/dist-packages/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "train.py", line 81, in update
    return opt_update(i, grad(loss)(params, batch), opt_state)
  File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 706, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 769, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 1846, in _vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py", line 114, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py", line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 516, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/usr/local/lib/python3.6/dist-packages/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 371, in f_jitted
    return cpp_jitted_f(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 284, in cache_miss
    donated_invars=donated_invars)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1229, in bind
    return call_bind(self, fun, *args, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1220, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1232, in process
    return trace.process_call(self, fun, tracers, params)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py", line 318, in process_call
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1229, in bind
    return call_bind(self, fun, *args, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1220, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1232, in process
    return trace.process_call(self, fun, tracers, params)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 193, in process_call
    f, in_pvals, app, instantiate=False)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 310, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1229, in bind
    return call_bind(self, fun, *args, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1220, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1232, in process
    return trace.process_call(self, fun, tracers, params)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 1085, in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/usr/local/lib/python3.6/dist-packages/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "train.py", line 20, in loss
    ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 525, in __getitem__
    def __getitem__(self, idx): return self.aval._getitem(self, idx)
  File "/usr/local/lib/python3.6/dist-packages/jax/_src/numpy/lax_numpy.py", line 4104, in _rewriting_take
    treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
  File "/usr/local/lib/python3.6/dist-packages/jax/_src/numpy/lax_numpy.py", line 4160, in _split_index_for_jit
    idx = _eliminate_deprecated_list_indexing(idx)
  File "/usr/local/lib/python3.6/dist-packages/jax/_src/numpy/lax_numpy.py", line 4408, in _eliminate_deprecated_list_indexing
    raise TypeError(msg)
TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[array(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.

The reason for this is that train.py indexes the arrays with plain Python ranges and lists, which the JAX authors have decided to deprecate, as can be seen here. It can be fixed by using jnp.array(idx) instead of idx (this PR does this for idx_train, idx_val and idx_test).
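A minimal sketch of the change (the 140 below is just Cora's standard training split size, used for illustration; the PR applies the same wrapping to idx_val and idx_test):

import jax.numpy as jnp

# Before: a plain Python range/list, which newer JAX versions reject
# when used to index an array inside a jitted function.
idx_train = range(140)

# After: wrap the indices in a JAX array once, then index as usual, e.g.
# -jnp.mean(jnp.sum(preds[idx_train] * targets[idx_train], axis=1))
idx_train = jnp.array(idx_train)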

The reason I'm running this is that I would like to implement the same thing as you, but using FLAX, the high-level deep learning API on top of JAX. I have a notebook which you can run (training GCN on Cora): https://drive.google.com/file/d/1D-GwuZH19p19RjnxuDbw4GsrmM3bCyNp/view?usp=sharing
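For context, the core of the Flax version is roughly a linen module like this (a minimal sketch for illustration, not the exact code from my notebook; names and details are placeholders):

import jax.numpy as jnp
import flax.linen as nn

class GCNLayer(nn.Module):
    """Single graph convolution: X' = A_hat @ (X @ W)."""
    features: int

    @nn.compact
    def __call__(self, x, adj):
        # Linear projection of the node features, then aggregation over the
        # (symmetrically normalized, self-loop-augmented) adjacency matrix.
        x = nn.Dense(self.features, use_bias=False)(x)
        return jnp.dot(adj, x)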

What's weird is that my initial loss is the same as yours (1.94), but after that the loss stays at 1.81 and doesn't change anymore. I'm using the same optimizer and learning rate. It would be great if you could take a look!

NielsRogge avatar Jan 03 '21 14:01 NielsRogge

Thanks for noticing this; I haven't run the code since I implemented it, so I was not aware of the error.

Is the issue with the loss stuck at 1.81 for this PR after the fix, or for the notebook using FLAX?

I'll take a look at the notebook; I've used Flax a bit, so maybe I can spot something weird.

gcucurull avatar Jan 04 '21 09:01 gcucurull

Yeah, I meant my notebook. Your code works just fine with my PR:

Iter 199/200 (0.0013 s) train_loss: 0.4010, train_acc: 0.9857, val_loss: 0.9357, val_acc: 0.7880
Test set acc: 0.8130000233650208

But my notebook seems to have an issue.

NielsRogge avatar Jan 04 '21 09:01 NielsRogge

Hi @gcucurull,

My GCN implementation in FLAX is working. However, I've also added the Graph Attention Network, and there's probably a mistake in there; could you please have a look?

I've created a copy of the notebook which you can edit: https://colab.research.google.com/drive/1YbRUgrnWhbMV2pJJMtpIjL51vAp4XVfW?usp=sharing

Thank you!

NielsRogge avatar Jan 13 '21 12:01 NielsRogge