easy-bert

"logits must be 2-dimensional" error on TF 1.9

Open • alpsholic opened this issue 4 years ago • 3 comments

I am getting the following exception on TF 1.9 when I load a saved BERT model in Java. I saved the model in Python using easy-bert and loaded it in Java, again with easy-bert (model: https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1).

The model loads fine, but when I try to extract embeddings, it throws the following. I have to stay on TF 1.9.

2020-04-16 18:00:31.003625: I tensorflow/cc/saved_model/loader.cc:291] SavedModel load for tags { serve }; Status: success. Took 1197645 microseconds.
Exception in thread "main" java.lang.IllegalArgumentException: logits must be 2-dimensional
[[Node: module_apply_tokens/bert/encoder/layer_0/attention/self/Softmax = SoftmaxT=DT_FLOAT, _output_shapes=[[?,12,?,?]], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)
at com.robrua.nlp.bert.Bert.embedSequence(Bert.java:252)
at .bert.TestEasyBert.main(TestEasyBert.java:17)

alpsholic • Apr 17 '20 01:04
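The failing node's input is listed as `_output_shapes=[[?,12,?,?]]`, a rank-4 tensor (batch, 12 attention heads, query positions, key positions), while the error says the op expects 2-D logits. One plausible reading, not confirmed in this thread, is that the TF 1.9 runtime's Softmax kernel only accepts matrices, whereas the saved graph applies Softmax directly to the rank-4 attention scores. The pure-Java sketch below illustrates that rank check and the usual way to run a last-axis softmax on a 2-D-only kernel: collapse all leading dimensions into one. The class and method names are hypothetical and are not part of easy-bert or TensorFlow.

```java
import java.util.Arrays;

public class SoftmaxRankDemo {

    /** Row-wise softmax over a row-major 2-D array; rejects any other rank (mimics the error). */
    static float[] softmax2d(float[] data, int[] shape) {
        if (shape.length != 2) {
            throw new IllegalArgumentException("logits must be 2-dimensional");
        }
        int rows = shape[0], cols = shape[1];
        float[] out = new float[data.length];
        for (int r = 0; r < rows; r++) {
            int base = r * cols;
            // Subtract the row maximum for numerical stability.
            float max = Float.NEGATIVE_INFINITY;
            for (int c = 0; c < cols; c++) max = Math.max(max, data[base + c]);
            double sum = 0.0;
            for (int c = 0; c < cols; c++) {
                out[base + c] = (float) Math.exp(data[base + c] - max);
                sum += out[base + c];
            }
            for (int c = 0; c < cols; c++) out[base + c] /= sum;
        }
        return out;
    }

    /** Collapses every dimension except the last, e.g. [b, h, q, k] -> [b*h*q, k]. */
    static int[] flattenOuterDims(int[] shape) {
        int outer = 1;
        for (int i = 0; i < shape.length - 1; i++) outer *= shape[i];
        return new int[] {outer, shape[shape.length - 1]};
    }

    public static void main(String[] args) {
        // Attention scores shaped [batch, heads, seq, seq] with batch=1, heads=12, seq=3.
        int[] shape = {1, 12, 3, 3};
        float[] scores = new float[1 * 12 * 3 * 3];
        for (int i = 0; i < scores.length; i++) scores[i] = i % 3;

        try {
            softmax2d(scores, shape);
        } catch (IllegalArgumentException e) {
            System.out.println("rank-4 input rejected: " + e.getMessage());
        }

        // Row-major data is unchanged by the reshape; only the shape metadata differs.
        int[] flat = flattenOuterDims(shape);
        float[] probs = softmax2d(scores, flat);
        System.out.println("flattened shape: " + Arrays.toString(flat));
        System.out.printf("first row: %.4f %.4f %.4f%n", probs[0], probs[1], probs[2]);
    }
}
```

Because softmax normalizes only along the last axis, flattening the outer dimensions gives the same probabilities as an N-D softmax; whether the fix belongs in the graph or in the TF version is exactly the question raised in the replies.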

First, I want to figure out whether this is a version-compatibility issue or a general bug:

Is this a problem with TF 1.9 specifically? Did it work on previous versions, and now you need to run it on 1.9 for some other reason?

If it's only on 1.9, did you bump the TF version on both the Python & Java ends?

There's a decent chance this is just fragility in how the input/output nodes in the graph are getting passed between Python & Java.

robrua • Apr 25 '20 19:04

If you can leave some more specifics about how you're obtaining, saving, and loading the model, that'd be helpful in figuring out what's wrong.

robrua • Apr 25 '20 19:04

I didn't try lower versions of TF. It works fine from TF 1.11 onwards but not on TF 1.9. I used your simple code to save the model on the Python side (and printed the TF version there). I then loaded it on the Java side and tried to extract embeddings using your Java sample code (printing the TF version there as well). Later I also tried the models from your pom entries (i.e., without saving in Python and then loading in Java).

I got this exception on both occasions.

alpsholic • Apr 28 '20 20:04
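Since the failure shows up on TF 1.9 but not from 1.11 onwards, and the first question was whether both ends run the same TF version, a Java-side guard that fails fast on an older runtime could save a confusing stack trace. This is a minimal sketch with hypothetical helper names; the 1.11 threshold is only what was observed in this thread, not a documented minimum for easy-bert. With the TensorFlow Java dependency on the classpath, the runtime version string comes from `org.tensorflow.TensorFlow.version()`; here it is hard-coded so the example runs on its own.

```java
public class TfVersionCheck {

    /** Parses "major.minor[.patch][-suffix]" into {major, minor}. */
    static int[] majorMinor(String version) {
        String[] parts = version.split("[.-]");
        return new int[] {Integer.parseInt(parts[0]), Integer.parseInt(parts[1])};
    }

    /** True if version >= major.minor, comparing numerically rather than as strings. */
    static boolean atLeast(String version, int major, int minor) {
        int[] v = majorMinor(version);
        return v[0] > major || (v[0] == major && v[1] >= minor);
    }

    public static void main(String[] args) {
        // In a real setup: String runtime = org.tensorflow.TensorFlow.version();
        String[] versions = {"1.9.0", "1.11.0", "1.15.0"};
        for (String v : versions) {
            System.out.println(v + " >= 1.11? " + atLeast(v, 1, 11));
        }
        // Plain string comparison gets this wrong: "1.9.0" sorts after "1.11.0".
        System.out.println("\"1.9.0\" sorts after \"1.11.0\": " + ("1.9.0".compareTo("1.11.0") > 0));
    }
}
```

Printing the same version string on the Python side (`tf.__version__`) and comparing the two is the quickest way to answer the "both ends" question.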