java-bert-predict
"logits must be 2-dimensional" exception
I followed the steps exactly as described: I downloaded the multilingual model and used the Python script to generate the .pb file. But on the Java side, I get the following error when I extract embeddings. For various reasons, I have to do this on TF 1.9 only (it works perfectly fine from TF 1.11 onwards). Any idea how I can get this working on TF 1.9?
2020-04-27 16:34:13.450298: I tensorflow/cc/saved_model/loader.cc:291] SavedModel load for tags { serve }; Status: success. Took 1703727 microseconds.
Exception in thread "main" java.lang.IllegalArgumentException: logits must be 2-dimensional
	 [[Node: bert/encoder/layer_0/attention/self/Softmax = Softmax[T=DT_FLOAT, _output_shapes=[[1,12,128,128]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](bert/encoder/layer_0/attention/self/add)]]
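The failing node, `bert/encoder/layer_0/attention/self/Softmax`, receives 4-D attention scores of shape [1,12,128,128], and the message suggests the TF 1.9 Softmax kernel only accepts 2-D logits, which would fit the graph working from TF 1.11 onwards. A workaround one might try (an assumption, not confirmed in this thread) is to reshape the scores to [1*12*128, 128] before the softmax and back to 4-D afterwards when building the exported graph. The plain-Java sketch below has no TensorFlow dependency, and its class and method names are hypothetical. It only checks why that reshape would be safe: for a row-major tensor, softmax along the last axis computes exactly what a row-wise softmax over the flattened 2-D matrix computes.

```java
import java.util.Random;

public class SoftmaxReshapeDemo {

    // Row-wise softmax over a [rows, cols] matrix stored row-major in a flat
    // array: the computation a Softmax kernel limited to 2-D logits performs.
    static float[] softmax2d(float[] flat, int rows, int cols) {
        float[] out = new float[flat.length];
        for (int r = 0; r < rows; r++) {
            int base = r * cols;
            // Subtract the row max for numerical stability.
            float max = Float.NEGATIVE_INFINITY;
            for (int c = 0; c < cols; c++) max = Math.max(max, flat[base + c]);
            double sum = 0.0;
            for (int c = 0; c < cols; c++) sum += Math.exp(flat[base + c] - max);
            for (int c = 0; c < cols; c++) {
                out[base + c] = (float) (Math.exp(flat[base + c] - max) / sum);
            }
        }
        return out;
    }

    // Reference: softmax along the last axis of a 4-D tensor such as the
    // attention scores in the error [batch, heads, seq, seq], using nested arrays.
    static float[][][][] softmax4d(float[][][][] x) {
        float[][][][] out = new float[x.length][][][];
        for (int i = 0; i < x.length; i++) {
            out[i] = new float[x[i].length][][];
            for (int j = 0; j < x[i].length; j++) {
                out[i][j] = new float[x[i][j].length][];
                for (int k = 0; k < x[i][j].length; k++) {
                    float[] row = x[i][j][k];
                    float max = Float.NEGATIVE_INFINITY;
                    for (float v : row) max = Math.max(max, v);
                    double sum = 0.0;
                    for (float v : row) sum += Math.exp(v - max);
                    out[i][j][k] = new float[row.length];
                    for (int l = 0; l < row.length; l++) {
                        out[i][j][k][l] = (float) (Math.exp(row[l] - max) / sum);
                    }
                }
            }
        }
        return out;
    }

    // Row-major flattening: element [i][j][k][l] lands at ((i*d1 + j)*d2 + k)*d3 + l,
    // which is exactly the layout a [d0*d1*d2, d3] reshape assumes.
    static float[] flatten(float[][][][] x) {
        int d0 = x.length, d1 = x[0].length, d2 = x[0][0].length, d3 = x[0][0][0].length;
        float[] flat = new float[d0 * d1 * d2 * d3];
        int idx = 0;
        for (float[][][] a : x)
            for (float[][] b : a)
                for (float[] c : b)
                    for (float v : c) flat[idx++] = v;
        return flat;
    }

    public static void main(String[] args) {
        // Small stand-in for the [1, 12, 128, 128] shape in the error message.
        int d0 = 1, d1 = 2, d2 = 3, d3 = 4;
        Random rng = new Random(42);
        float[][][][] logits = new float[d0][d1][d2][d3];
        for (float[][][] a : logits)
            for (float[][] b : a)
                for (float[] c : b)
                    for (int l = 0; l < d3; l++) c[l] = (float) rng.nextGaussian();

        // Path 1: view as [d0*d1*d2, d3] and apply a 2-D-only softmax.
        float[] viaReshape = softmax2d(flatten(logits), d0 * d1 * d2, d3);
        // Path 2: softmax directly on the 4-D tensor, then flatten for comparison.
        float[] reference = flatten(softmax4d(logits));

        float maxDiff = 0f;
        for (int i = 0; i < reference.length; i++) {
            maxDiff = Math.max(maxDiff, Math.abs(viaReshape[i] - reference[i]));
        }
        System.out.println("max abs difference: " + maxDiff);
    }
}
```

Running `main` should report a difference of zero (or at most float rounding), which is the property a reshape-before-Softmax graph change would rely on. Whether such a change actually lets the exported model load and run on TF 1.9 is untested here.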
PS: Thank you so much for this project. It was exactly the Python script and Java module I needed.