trax
ValueError with a Transformer model in predict mode
Description
I get the following error when I load the model in predict mode. It works perfectly in eval mode.
ValueError: Incompatible shapes for matmul arguments: (8, 1, 64) and (256, 64, 2048)
Model definition:
import trax

model = trax.models.Transformer(
    input_vocab_size=33600,
    d_model=512, d_ff=2048, dropout=0.1,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode=mode)  # mode is 'eval' (works) or 'predict' (fails)
...
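The numbers in the error line up with these hyperparameters. The sketch below assumes (an inference from the traceback, not something stated in the trax docs) that in predict mode each causal-attention layer keeps a key cache whose leading axis merges batch and heads, with layout (batch * n_heads, max_len, d_head). Then d_head = 512 / 8 = 64 and the cache length is max_len = 2048, both matching the error. Under that assumption, the query (8, 1, 64) is one sentence times 8 heads, while the 256 leading rows of the cached keys correspond to a batch of 32, so the cache appears to have been allocated for a different batch size than the one used while decoding. The helper names are hypothetical.

```python
def predict_cache_shape(batch_size, n_heads, d_model, max_len):
    """Shape of one attention layer's key (or value) cache, ASSUMING the
    (batch * n_heads, max_len, d_head) layout inferred from the traceback."""
    d_head = d_model // n_heads  # 512 // 8 = 64, the last dim in the error
    return (batch_size * n_heads, max_len, d_head)


def batch_size_from_cache(leading_dim, n_heads):
    """Undo the merged batch*heads axis to recover the batch size."""
    assert leading_dim % n_heads == 0, "leading dim should be a multiple of n_heads"
    return leading_dim // n_heads


# Hyperparameters from the model definition above.
N_HEADS, D_MODEL, MAX_LEN = 8, 512, 2048

# Decoding one sentence at a time needs a cache built for batch 1 ...
print(predict_cache_shape(1, N_HEADS, D_MODEL, MAX_LEN))  # (8, 2048, 64)
# ... but the cached keys in the error have 256 rows on the leading axis.
print(batch_size_from_cache(256, N_HEADS))  # 32
```

If this reading is right, a plausible (unverified) fix is to make sure the predict-mode model's state is created for batch size 1, for example by not initializing it with a batch drawn from the training or eval stream, and by loading only the trained weights from the checkpoint with `model.init_from_file(path, weights_only=True)`.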
Environment information
OS: Google Colab
$ pip freeze | grep trax
trax 1.3.9
$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.5.0
tensorflow-datasets==4.0.1
tensorflow-estimator==2.5.0
tensorflow-gcs-config==2.5.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.1.0
tensorflow-probability==0.13.0
tensorflow-text==2.5.0
$ pip freeze | grep jax
jax==0.2.17
jaxlib==0.1.69+cuda110
$ python -V
Python 3.7
### For bugs: reproduction and error logs
Steps to reproduce:
...
def sampling_decode(input_sentence, model=None, temperature=0.0, vocab_file=None, vocab_dir=None):
    # tokenize, detokenize and next_symbol are helper functions defined elsewhere (not shown).
    input_tokens = tokenize(input_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir)
    cur_output_tokens = []
    cur_output = 0
    EOS = 1
    # Generate one token per step until the model emits EOS.
    while cur_output != EOS:
        cur_output, log_prob = next_symbol(model, input_tokens, cur_output_tokens, temperature)
        cur_output_tokens.append(cur_output)
    sentence = detokenize(cur_output_tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)
    return cur_output_tokens, log_prob, sentence
eval_point = random.choice(eval_data)
incorrect_sentence = eval_point[0]
correct_sentence = eval_point[1]
print("Incorrect sentence :- ",incorrect_sentence)
print("Correct sentence :- ",correct_sentence)
pred_token, log_prob, pred_sentence = sampling_decode(incorrect_sentence, eval_model, temperature=0.0, vocab_file=vocab_file, vocab_dir=vocab_dir)
print("Predicted sentence :- ",pred_sentence)
print("Predicted tokens :- ", pred_token)
print("log_prob :- ",log_prob)
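In predict mode a trax model is stateful: each call is expected to receive only the newest token and to attend over keys and values cached from earlier calls, rather than re-reading the whole prefix as an eval-mode call does. Whether the `next_symbol` helper used above (its body is not shown) follows that contract cannot be told from the issue. The sketch below is a toy, hypothetical stand-in (`ToyCachedDecoder` and `greedy_decode` are made-up names; there is no real attention and the token ids are dummies) that mimics the contract: its cache is allocated once for a fixed batch size, and one token is fed per step. It also caps the number of steps, which the `while cur_output != EOS` loop above does not, so decoding terminates even if EOS is never produced.

```python
import numpy as np


class ToyCachedDecoder:
    """Hypothetical stand-in for a predict-mode decoder: it keeps a
    preallocated key cache and consumes ONE token per call."""

    def __init__(self, batch_size, n_heads=8, d_head=64, max_len=2048):
        # Layout mirrors the shapes in the traceback: (batch * n_heads, max_len, d_head).
        self.cache = np.zeros((batch_size * n_heads, max_len, d_head), np.float32)
        self.n_heads = n_heads
        self.pos = 0  # number of tokens already written to the cache

    def step(self, token_batch):
        # token_batch has shape (batch, 1): one new token per sequence.
        batch = token_batch.shape[0]
        if batch * self.n_heads != self.cache.shape[0]:
            raise ValueError(
                f"cache built for batch {self.cache.shape[0] // self.n_heads}, "
                f"got batch {batch}")
        if self.pos >= self.cache.shape[1]:
            raise ValueError("cache is full (max_len reached)")
        # Store a placeholder "key" for the new token in the next free slot.
        self.cache[:, self.pos, :] = np.repeat(token_batch, self.n_heads, axis=0)
        self.pos += 1
        # Toy rule in place of attention + softmax: emit EOS (id 1) on the 3rd call.
        return 1 if self.pos >= 3 else self.pos + 1


def greedy_decode(model, max_length=10, eos_id=1):
    """Feed the previous output back one token at a time (predict-mode style)."""
    token, outputs = 0, []  # start from id 0, as sampling_decode does
    for _ in range(max_length):  # the cap guards against a missing EOS
        token = model.step(np.array([[token]]))
        outputs.append(token)
        if token == eos_id:
            break
    return outputs


print(greedy_decode(ToyCachedDecoder(batch_size=1)))  # [2, 3, 1]
```

Constructing the toy with `batch_size=32` and then stepping it with a single sentence raises a `ValueError`, the same kind of mismatch as in the traceback. For real models, trax provides `trax.supervised.decoding.autoregressive_sample`, which drives a predict-mode model token by token and may be a useful point of comparison.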
Error logs:
... LayerError: Exception passing through layer Serial (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 390 layer input shapes: (ShapeDtype{shape:(1, 204), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64})
File [...]/trax/layers/combinators.py, line 88, in forward outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Serial (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 566 layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}
File [...]/trax/layers/combinators.py, line 88, in forward outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Branch (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 566 layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}
File [...]/trax/layers/combinators.py, line 88, in forward outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Parallel (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 566 layer input shapes: (ShapeDtype{shape:(1, 1, 512), dtype:float32}, ShapeDtype{shape:(1, 1, 512), dtype:float32})
File [...]/trax/layers/combinators.py, line 211, in forward sub_outputs, sub_state = layer.pure_fn(x, w, s, r, use_cache=True)
LayerError: Exception passing through layer Serial (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 566 layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}
File [...]/trax/layers/combinators.py, line 88, in forward outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Serial (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 556 layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}
File [...]/trax/layers/combinators.py, line 88, in forward outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Serial (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 556 layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}
File [...]/trax/layers/combinators.py, line 88, in forward outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Serial (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 556 layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}
File [...]/trax/layers/combinators.py, line 88, in forward outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer DotProductCausalAttention (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 556 layer input shapes: (ShapeDtype{shape:(8, 1, 64), dtype:float32}, ShapeDtype{shape:(8, 1, 64), dtype:float32}, ShapeDtype{shape:(8, 1, 64), dtype:float32})
File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper y = forward(self, x, *args, **kwargs)
File [...]/trax/layers/attention.py, line 520, in forward q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)
File [...]/trax/layers/attention.py, line 281, in _per_head_attention dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
File [...]/_src/numpy/lax_numpy.py, line 4211, in matmul .format(shape(a), shape(b)))
ValueError: Incompatible shapes for matmul arguments: (8, 1, 64) and (256, 64, 2048)
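The last frames show the failing contraction in `_per_head_attention`. The sketch below reproduces it outside trax, using NumPy in place of `jax.numpy` (both treat every axis except the last two as a batch axis in `matmul`, and batch axes of size 8 and 256 cannot be broadcast together). The key shape (256, 2048, 64) is inferred by undoing the `swapaxes` shown in the log; `per_head_dots` is a hypothetical helper that copies only the expression on the failing line.

```python
import numpy as np


def per_head_dots(queries, keys):
    """Same expression as the failing line in trax/layers/attention.py:
    queries @ swapaxes(keys, -1, -2), scaled by sqrt(d_feature)."""
    d_feature = queries.shape[-1]
    return np.matmul(queries, np.swapaxes(keys, -1, -2)) / np.sqrt(d_feature)


def placeholder(shape):
    # Zero-copy array of the given shape, so the large key caches cost no memory.
    return np.broadcast_to(np.float32(0), shape)


q = placeholder((8, 1, 64))               # 1 sentence x 8 heads, 1 new token
keys_ok = placeholder((8, 2048, 64))      # cache sized for batch 1
keys_bad = placeholder((256, 2048, 64))   # cache sized for batch 32

print(per_head_dots(q, keys_ok).shape)  # (8, 1, 2048)
try:
    per_head_dots(q, keys_bad)
except ValueError as err:
    # Leading (batch) axes 8 and 256 cannot broadcast, as in the log above.
    print("ValueError:", err)
```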
I'm also having exactly the same problem. My config is:
trax 1.3.8
mesh-tensorflow 0.1.19
tensor2tensor 1.15.7
tensorboard 2.6.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.0
tensorflow 2.5.0
tensorflow-addons 0.14.0
tensorflow-datasets 4.4.0
tensorflow-estimator 2.5.0
tensorflow-gan 2.1.0
tensorflow-hub 0.12.0
tensorflow-metadata 1.2.0
tensorflow-probability 0.7.0
tensorflow-text 2.5.0
jax 0.2.21
jaxlib 0.1.71+cuda111
OS: Ubuntu 20.04.3 LTS
Python version: 3.8.12
Have you found any solution?
I didn't find a solution, so I switched to PyTorch. My guess is that it was an error in the parallel computation: maybe some values were not processed, which led to mismatched dimensions.
(Sent by email in reply to Davidh's comment of Wed, Nov 3, 2021, quoted above: https://github.com/google/trax/issues/1670#issuecomment-959719024)