
Trax raises an error with my own model: TypeError: Slice size at index 0 in gather op is out of range, must be within [0, 1), got 1.

Open leealex0201 opened this issue 4 years ago • 2 comments

Description

Hello. I am taking the Coursera NLP course on attention models, which uses the Trax package. In the course assignment, the task is to train a model and evaluate its performance. Because training takes a long time, the instructions only have students write the training script as a test, and they provide a saved model file instead.

I tried to train my own model on my own machine by running Trax's training loop for 10,000 steps. Here is the code:

!rm -f ~/model/model.pkl.gz
loop = training_loop(TransformerLM, train_batch_stream, eval_batch_stream)
loop.run(10000)

where the training_loop function has the following signature:

def training_loop(TransformerLM, train_gen, eval_gen, output_dir="~/model"):

and TransformerLM is the function that builds the model.

With this, I got an eval accuracy of 0.233. To evaluate my model's performance, I switched the model to "eval" mode and went ahead with testing it. Evaluating the model is also part of the course, and the instructor prepared a code block that loads the provided model (which is supposed to have been trained by the instructor and to perform much better than mine). I commented out that loading code and ran the evaluation part as follows.

sentence_test_nxt_symbl = "I want to fly in the sky."
detokenize([next_symbol(tokenize(sentence_test_nxt_symbl)+[0], model)])

where detokenize converts a tensor back to a string, next_symbol returns the next symbol (as a tensor) for a given tokenized sentence, and tokenize converts a string to a tensor.
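For context, the traceback shows that next_symbol pads the token list and adds a batch dimension before calling the model, which then receives input of shape (1, 16). A minimal stdlib sketch of that padding step follows. It assumes, as the (1, 16) shape suggests, zero-padding up to the next power of two; pad_with_batch is a hypothetical name, and the real function works on arrays rather than lists.

```python
import math


def pad_with_batch(tokens):
    """Hypothetical helper: zero-pad `tokens` and add a batch axis.

    Assumes padding up to the next power of two that is strictly greater
    than len(tokens), so position len(tokens) always exists in the output.
    """
    token_length = len(tokens)
    padded_length = 2 ** math.ceil(math.log2(token_length + 1))
    padded = tokens + [0] * (padded_length - token_length)
    # A list of one row stands in for an array of shape (1, padded_length).
    return [padded]


batch = pad_with_batch([5, 6, 7])
print(batch)  # [[5, 6, 7, 0]]
```

Padding to a length strictly greater than the input guarantees that the output has a position at index len(tokens), which is where the prediction for the next symbol would be read.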

When I run this code, I get the following error:

---------------------------------------------------------------------------
LayerError                                Traceback (most recent call last)
<ipython-input-38-9431ff725bfd> in <module>
      1 # Test it out!
      2 sentence_test_nxt_symbl = "I want to fly in the sky."
----> 3 detokenize([next_symbol(tokenize(sentence_test_nxt_symbl)+[0], model)])

<ipython-input-37-ec02cc9ac335> in next_symbol(cur_output_tokens, model)
     24 
     25     # model expects a tuple containing two padded tensors (with batch)
---> 26     output, _ = model((padded_with_batch, padded_with_batch))
     27     # HINT: output has shape (1, padded_length, vocab_size)
     28     # To get log_probs you need to index output with 0 in the first dim

~/anaconda3/lib/python3.8/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
    171       self.state = state  # Needed if the model wasn't fully initialized.
    172     state = self.state
--> 173     outputs, new_state = self.pure_fn(x, weights, state, rng)
    174     self.state = new_state
    175     self.weights = weights

~/anaconda3/lib/python3.8/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    520       # Skipping 3 lines as it's always the uninteresting internal call.
    521       name, trace = self._name, _short_traceback(skip=3)
--> 522       raise LayerError(name, 'pure_fn',
    523                        self._caller, signature(x), trace) from None
    524 

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/<ipython-input-30-412c88842307>, line 50
  layer input shapes: (ShapeDtype{shape:(1, 16), dtype:int64}, ShapeDtype{shape:(1, 16), dtype:int64})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Embedding_33300_512 (in pure_fn):
  layer created in file [...]/<ipython-input-30-412c88842307>, line 38
  layer input shapes: ShapeDtype{shape:(1, 16), dtype:int32}

  File [...]/trax/layers/core.py, line 150, in forward
    return jnp.take(self.weights, x, axis=0)

  File [...]/_src/numpy/lax_numpy.py, line 4077, in take
    return lax.gather(a, indices[..., None], dimension_numbers=dnums,

  File [...]/_src/lax/lax.py, line 872, in gather
    return gather_p.bind(

  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/site-packages/jax/core.py, line 628, in process_primitive
    return primitive.impl(*tracers, **params)

  File [...]/jax/interpreters/xla.py, line 238, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)

  File [...]/jax/_src/util.py, line 198, in wrapper
    return cached(bool(FLAGS.jax_enable_x64), *args, **kwargs)

  File [...]/jax/_src/util.py, line 191, in cached
    return f(*args, **kwargs)

  File [...]/jax/interpreters/xla.py, line 263, in xla_primitive_callable
    aval_out = prim.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 1992, in standard_abstract_eval
    shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)

  File [...]/_src/lax/lax.py, line 4234, in _gather_shape_rule
    raise TypeError(f"Slice size at index {i} in gather op is out of range, "

TypeError: Slice size at index 0 in gather op is out of range, must be within [0, 1), got 1.
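Reading the last frame: JAX raises this when a gather's slice size is larger than the operand along that axis. An embedding lookup takes one row per token, so the bound [0, 1) indicates that the weight table passed to jnp.take had zero rows, i.e. the Embedding_33300_512 layer had no weights loaded (a populated table would have 33300 rows). The stdlib-only sketch below is a toy analogue of that shape check, not trax or JAX code; gather_rows is a hypothetical name.

```python
def gather_rows(table, token_ids):
    """Toy analogue of an embedding lookup (hypothetical, not trax/JAX code).

    `table` is a list of rows. Each token id selects one row, so the gather
    uses a slice of size 1 along axis 0, which only fits if the table has
    at least one row. Out-of-range ids are not modeled here.
    """
    slice_size = 1
    n_rows = len(table)
    # Same bound as in the traceback: slice_size must lie in [0, n_rows + 1).
    if not 0 <= slice_size < n_rows + 1:
        raise TypeError(
            f"Slice size at index 0 in gather op is out of range, "
            f"must be within [0, {n_rows + 1}), got {slice_size}."
        )
    return [table[token_id] for token_id in token_ids]


uninitialized = []                  # weights never loaded: zero rows
trained = [[0.1, 0.2], [0.3, 0.4]]  # toy 2-row table

print(gather_rows(trained, [1, 0]))  # [[0.3, 0.4], [0.1, 0.2]]
try:
    gather_rows(uninitialized, [1, 0])
except TypeError as err:
    print(err)  # ... must be within [0, 1), got 1.
```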

What puzzles me is that everything works fine with the provided saved model; this error happens only when I use my own model.

Could you please let me know how to resolve this issue?

Thank you!

Environment information

OS: Windows WSL2 Ubuntu 18.02

$ pip freeze | grep trax
trax==1.3.4

$ pip freeze | grep tensor
mesh-tensorflow==0.1.18
tensor2tensor==1.15.7
tensorboard==2.4.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.4.1
tensorflow-addons==0.12.1
tensorflow-datasets==4.2.0
tensorflow-estimator==2.4.0
tensorflow-gan==2.0.0
tensorflow-gpu==2.4.1
tensorflow-hub==0.11.0
tensorflow-metadata==0.27.0
tensorflow-probability==0.7.0
tensorflow-text==2.4.3

$ pip freeze | grep jax
jax==0.2.9
jaxlib==0.1.59+cuda110

$ python -V
Python 3.8.5


leealex0201, Feb 17 '21 15:02

Replace model with training_loop.model, that is, the model attribute of the object returned by training_loop (loop in your code). You are getting this error because the model object you are using was never initialized or trained; the trained instance can be accessed through training_loop.model.

vaibhavtmnit, Feb 20 '21 04:02
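A minimal sketch of a second option: rebuild the model in eval mode and load the weights your own run saved, instead of the provided file. It assumes that TransformerLM accepts a mode argument, that the loop wrote model.pkl.gz into output_dir (the file removed by the rm -f command in the question), and that the trax layer method init_from_file with a weights_only flag is available in your trax version; load_own_model is a hypothetical helper.

```python
import os


def load_own_model(model_fn, output_dir="~/model"):
    """Build an eval-mode model and load weights from your own checkpoint.

    Hypothetical helper. Assumes `model_fn` (e.g. TransformerLM) accepts
    `mode=` and that training wrote `model.pkl.gz` into `output_dir`.
    """
    # Expand "~" explicitly so the path is absolute before it is opened.
    checkpoint = os.path.join(os.path.expanduser(output_dir), "model.pkl.gz")
    model = model_fn(mode="eval")
    # weights_only=True loads the trained weights but not the layer state.
    model.init_from_file(checkpoint, weights_only=True)
    return model
```

Usage would look like model = load_own_model(TransformerLM), after which next_symbol can be called as in the question. The eval-mode model has to be built with the same hyperparameters used during training, otherwise the stored weight shapes will not match. The simplest route is still model = loop.model, which reuses the weights already in memory.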

Thanks @vaibhavtmnit. I got the same error initially, but after making the change you suggested, it worked.

Smriti1996, Jun 21 '21 06:06