ValueError on predict mode Transformer model

Open shashank2123 opened this issue 4 years ago • 2 comments

Description

I get the following error when I load the model in predict mode. It works perfectly in eval mode.

ValueError: Incompatible shapes for matmul arguments: (8, 1, 64) and (256, 64, 2048)

Model definition:

model = trax.models.Transformer(
      input_vocab_size=33600,
      d_model=512, d_ff=2048, dropout = 0.1,
      n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
      max_len=2048, mode=mode) 

...

Environment information

OS:  I am using colab

$ pip freeze | grep trax
trax                          1.3.9  

$ pip freeze | grep tensor

mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.5.0
tensorflow-datasets==4.0.1
tensorflow-estimator==2.5.0
tensorflow-gcs-config==2.5.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.1.0
tensorflow-probability==0.13.0
tensorflow-text==2.5.0

$ pip freeze | grep jax
jax==0.2.17
jaxlib==0.1.69+cuda110

$ python -V
python 3.7

For bugs: reproduction and error logs

Steps to reproduce:

...

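def sampling_decode(input_sentence, model=None, temperature=0.0, vocab_file=None, vocab_dir=None):
    # Tokenize the source sentence once; it is passed to next_symbol on every step.
    input_tokens = tokenize(input_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir)

    cur_output_tokens = []  # tokens generated so far
    cur_output = 0          # last generated token
    EOS = 1                 # end-of-sentence token id

    # Generate one token at a time until the model emits EOS.
    while cur_output != EOS:
        cur_output, log_prob = next_symbol(model, input_tokens, cur_output_tokens, temperature)
        cur_output_tokens.append(cur_output)

    sentence = detokenize(cur_output_tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)

    # log_prob is the value returned for the last generated token only.
    return cur_output_tokens, log_prob, sentence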

eval_point = random.choice(eval_data)

incorrect_sentence = eval_point[0]
correct_sentence = eval_point[1]

print("Incorrect sentence :- ",incorrect_sentence)
print("Correct sentence :- ",correct_sentence)

pred_token, log_prob, pred_sentence = sampling_decode(incorrect_sentence, eval_model, temperature=0.0, vocab_file=vocab_file, vocab_dir=vocab_dir)

print("Predicted sentence :- ",pred_sentence)
print("Predicted tokens :- ",pred_token)
print("log_prob :- ",log_prob)
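
`next_symbol` is not shown in this report, so what `temperature=0.0` does inside it is unknown. As a rough illustration only, one common way to pick a token from log-probabilities is Gumbel-max sampling, where a temperature of 0 reduces to a greedy argmax. `sample_from_log_probs` below is a hypothetical NumPy helper written for this sketch; it is not the reporter's code and not a trax API.

```python
import numpy as np


def sample_from_log_probs(log_probs, temperature=0.0, rng=None):
    """Hypothetical helper: pick a token id from 1-D log-probabilities."""
    log_probs = np.asarray(log_probs, dtype=np.float64)
    if temperature == 0.0:
        # Greedy decoding: always take the most likely token.
        return int(np.argmax(log_probs))
    rng = rng if rng is not None else np.random.default_rng()
    # Gumbel noise; the bounds keep both logarithms finite.
    u = rng.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    gumbel = -np.log(-np.log(u))
    # Scaling the noise by the temperature controls how random the pick is.
    return int(np.argmax(log_probs + temperature * gumbel))


log_probs = np.log([0.1, 0.7, 0.2])
greedy_token = sample_from_log_probs(log_probs, temperature=0.0)
sampled_token = sample_from_log_probs(
    log_probs, temperature=1.0, rng=np.random.default_rng(0))
print(greedy_token, sampled_token)
```

Under this sketch, `temperature=0.0` would make decoding deterministic, so the loop in `sampling_decode` would produce the same tokens on every call for a given input.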

Error logs:

...
LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/models/transformer.py, line 390
  layer input shapes: (ShapeDtype{shape:(1, 204), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/models/transformer.py, line 566
  layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Branch (in pure_fn):
  layer created in file [...]/trax/models/transformer.py, line 566
  layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Parallel (in pure_fn):
  layer created in file [...]/trax/models/transformer.py, line 566
  layer input shapes: (ShapeDtype{shape:(1, 1, 512), dtype:float32}, ShapeDtype{shape:(1, 1, 512), dtype:float32})

  File [...]/trax/layers/combinators.py, line 211, in forward
    sub_outputs, sub_state = layer.pure_fn(x, w, s, r, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/models/transformer.py, line 566
  layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/models/transformer.py, line 556
  layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/models/transformer.py, line 556
  layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/models/transformer.py, line 556
  layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer DotProductCausalAttention (in pure_fn):
  layer created in file [...]/trax/models/transformer.py, line 556
  layer input shapes: (ShapeDtype{shape:(8, 1, 64), dtype:float32}, ShapeDtype{shape:(8, 1, 64), dtype:float32}, ShapeDtype{shape:(8, 1, 64), dtype:float32})

  File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
    y = forward(self, x, *args, **kwargs)

  File [...]/trax/layers/attention.py, line 520, in forward
    q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)

  File [...]/trax/layers/attention.py, line 281, in _per_head_attention
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)

  File [...]/_src/numpy/lax_numpy.py, line 4211, in matmul
    .format(shape(a), shape(b)))

ValueError: Incompatible shapes for matmul arguments: (8, 1, 64) and (256, 64, 2048)
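
One way to read the two shapes in the final `ValueError`, offered as an interpretation rather than a confirmed diagnosis: with `n_heads=8` and `d_model=512` from the model definition, each head has 64 features, so `(8, 1, 64)` is consistent with queries for a batch of 1, and `(256, 64, 2048)` is consistent with transposed keys of length `max_len=2048` for a batch of 32. The sketch below is plain NumPy shape arithmetic only. `matmul_shape` is a hypothetical helper, not part of trax or JAX, and the assumption that the leading dimension equals batch size times number of heads is labeled in the code.

```python
import numpy as np

# Values taken from the model definition in this issue.
N_HEADS = 8
D_MODEL = 512
MAX_LEN = 2048
D_HEAD = D_MODEL // N_HEADS  # 64


def matmul_shape(a_shape, b_shape):
    """Hypothetical helper: result shape of a batched matmul, or ValueError."""
    if a_shape[-1] != b_shape[-2]:
        raise ValueError(f"inner dimensions differ: {a_shape} and {b_shape}")
    # Leading (batch) dimensions must broadcast against each other.
    batch = np.broadcast_shapes(a_shape[:-2], b_shape[:-2])
    return tuple(batch) + (a_shape[-2], b_shape[-1])


queries_shape = (8, 1, 64)      # first shape in the error message
keys_t_shape = (256, 64, 2048)  # second shape (keys after swapaxes)

# Assumption: the leading dimension is batch_size * n_heads.
query_batch = queries_shape[0] // N_HEADS  # 1
key_batch = keys_t_shape[0] // N_HEADS     # 32

try:
    matmul_shape(queries_shape, keys_t_shape)
    mismatch = False
except ValueError:
    mismatch = True  # leading dims 8 and 256 cannot be broadcast together

# With keys sized for the same batch as the queries, the shapes line up.
ok_shape = matmul_shape(queries_shape, (query_batch * N_HEADS, D_HEAD, MAX_LEN))
print(query_batch, key_batch, mismatch, ok_shape)
```

If this reading is right, the keys held by the predict-mode model were sized for a different batch than the `(1, 1)` decoder input shown at the top of the traceback, so it may be worth checking which batch size the model state was initialized with.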

shashank2123 avatar Jul 16 '21 20:07 shashank2123

I'm also having the exact same problem, my config is:

trax 1.3.8
mesh-tensorflow 0.1.19
tensor2tensor 1.15.7
tensorboard 2.6.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.0
tensorflow 2.5.0
tensorflow-addons 0.14.0
tensorflow-datasets 4.4.0
tensorflow-estimator 2.5.0
tensorflow-gan 2.1.0
tensorflow-hub 0.12.0
tensorflow-metadata 1.2.0
tensorflow-probability 0.7.0
tensorflow-text 2.5.0
jax 0.2.21
jaxlib 0.1.71+cuda111

OS: Ubuntu 20.04.3 LTS

Python version: 3.8.12

Have you found any solution?

David-hg avatar Nov 03 '21 16:11 David-hg

I didn't find a solution, so I switched to PyTorch. It looked like some error in parallel computing: maybe some values are not processed, which leads to the mismatched dimensions.

shashank2123 avatar Nov 03 '21 17:11 shashank2123