How to add training = True in tensorflow-java?
The models that I exported from Python work only if I pass training = True, and this is the case in Python as well.
For instance, in Python:
import tensorflow as tf
my_model = tf.keras.models.load_model("model")
prdct = my_model(input_image, training = True)
Here, if I do not provide training = True, the values in prdct.numpy() are all "nan".
I have the same problem in tensorflow/java as well. Is there a way to pass the option training = True in tensorflow/java?
Here is the Java code:
TFloat32 ImagePredicted = (TFloat32) sess
    .runner()
    .feed(input_layer, inputTensor)
    .fetch(outputLayer)
    .run()
    .get(0);
Depending on how you've loaded your model and how it's specified, you might be able to feed in a boolean (scalar) tensor that specifies whether it's training or not. I've done this before using the TF 1.x bindings to change the behaviour of a dropout layer. However, that's going to depend on the model you've built in Keras, and on how it gets exported from Keras and loaded into TF-Java.
This is the model in Python/TensorFlow (versions 2.2 and 2.4.1) and Java TF2 (tensorflow-core-platform, version 0.3.1).
The details of the model can be found here.
import tensorflow as tf

def downsample(filters, size, apply_batchnorm = True):
    #initializer = tf.random_normal_initializer(0., 0.02)
    initializer = "he_normal"
    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2D(filters, size, strides = 2, padding='same', kernel_initializer=initializer, use_bias=False))
    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
    result.add(tf.keras.layers.LeakyReLU())
    return result

def upsample(filters, size, apply_dropout=False):
    #initializer = tf.random_normal_initializer(0., 0.02)
    initializer = "he_normal"
    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2DTranspose(filters, size, strides = 2, padding='same', kernel_initializer=initializer, use_bias=False))
    result.add(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))
    result.add(tf.keras.layers.LeakyReLU())
    return result

def Generator(input_shape = (128,128,1)):
    inputs = tf.keras.layers.Input(input_shape)
    down_stack = [
        downsample(64, 4, apply_batchnorm = False),
        downsample(128, 4),
        downsample(256, 4),
        downsample(256, 4),
        downsample(512, 4),
        downsample(512, 4),
        downsample(512, 4),
    ]
    up_stack = [
        upsample(512, 4, apply_dropout = True),
        upsample(512, 4, apply_dropout = True),
        upsample(512, 4, apply_dropout = True),
        upsample(256, 4),
        upsample(256, 4),
        upsample(128, 4),
        upsample(64, 4),
    ]
    initializer = "he_normal"
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4, strides = 2, padding='same', kernel_initializer=initializer, activation='tanh')
    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)
    skips = reversed(skips[:-1])
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])
    x = last(x)
    return tf.keras.Model(inputs = inputs, outputs = x)

generator = Generator(input_shape)
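Tracing the feature-map sizes through the code above can help when matching it against an exported graph. The sketch below is plain Python arithmetic, not TensorFlow; `spatial_sizes` is a hypothetical helper, and it assumes a 128x128 input and that every stride-2 block halves or doubles the side length exactly. One thing it shows is that zip stops at the shorter sequence, so with seven up blocks and six skips the last up block is never applied.

```python
def spatial_sizes(size=128, n_down=7, n_up=7):
    """Track the feature-map side length through the stride-2 blocks."""
    sizes = [size]
    skips = []
    for _ in range(n_down):
        size //= 2              # Conv2D, strides=2, padding='same'
        skips.append(size)
        sizes.append(size)
    skips = list(reversed(skips[:-1]))
    used_up = 0
    for _, skip in zip(range(n_up), skips):  # zip stops at the shorter input
        size *= 2               # Conv2DTranspose, strides=2, padding='same'
        assert size == skip     # Concatenate needs matching spatial sizes
        sizes.append(size)
        used_up += 1
    size *= 2                   # the final Conv2DTranspose ('last')
    sizes.append(size)
    return sizes, used_up

sizes, used_up = spatial_sizes()
print(sizes)    # [128, 64, 32, 16, 8, 4, 2, 1, 2, 4, 8, 16, 32, 64, 128]
print(used_up)  # 6
```

The output side length equals the input side length, and only six of the seven up blocks take part in the forward pass.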
After training, I simply save the model as:
generator.save("model")
Then I load it in Python as below and set training = True:
my_model = tf.keras.models.load_model("model")
prdct = my_model(input_image, training = True)  # input_image is a numpy array of shape (1, 128, 128, 1)
Now in Java (tensorflow-core-platform, version 0.3.1) I do the following.
SavedModelBundle theModel = SavedModelBundle.load("model", "serve");
Session sess = theModel.session();
// convert the BufferedImage to a float array and then to a tensor
float[][][][] floatImage = ConvertTtoFloatImage(image, imgHeight, imgWidth); // imgHeight and imgWidth are 128, 128
TFloat32 inputTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(floatImage));
// perform prediction
TFloat32 ImagePredicted = (TFloat32) sess
    .runner()
    .feed("serving_default_input_1", inputTensor)
    .fetch("StatefulPartitionedCall")
    .run()
    .get(0);
Now all values in ImagePredicted are -1.
I am not sure if this could be a solution.
public void feed(String inputName, boolean[] src, long... dims) {
    byte[] b = new byte[src.length];
    for (int i = 0; i < src.length; i++) {
        b[i] = src[i] ? (byte) 1 : (byte) 0;
    }
    addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
}
I could not test it, because Tensor.create raises this error: The method create(Class<Boolean>, long[], ByteBuffer) is undefined for the type Tensor.
You should use TBool.class rather than Boolean.class. However, I'm not sure that will achieve what you want. The effect of the training=true flag in Keras depends on the layers used and on how the model has been constructed (see the call definition here - https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/python/keras/engine/training.py#L428). If that eventually unfolds into a graph which expects a boolean tensor with a specific name, then you can feed it (and it tended to in TF 1.x). However, Keras is more tightly integrated into TF 2.x and many of the layers have more complex behaviour, so you'll need to trace through what your Python code does and figure out what the effect of that flag is for your specific model.
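To make the point about layer-dependent behaviour concrete, here is a minimal pure-Python sketch of what a batch-normalization layer computes in each mode. It is not the Keras implementation: the function name `batch_norm` and the sample numbers are made up for illustration, and `eps = 1e-3` is only assumed to mirror the Keras default. With training = True the layer normalizes with the statistics of the current batch; otherwise it uses the stored moving mean and variance. A bad stored statistic (here a NaN variance, chosen purely as a hypothetical, not as a diagnosis of the model in this thread) therefore only shows up in inference mode.

```python
import math

def batch_norm(values, moving_mean, moving_var, training,
               gamma=1.0, beta=0.0, eps=1e-3):
    """Normalize a 1-D batch in training or inference mode (illustrative only)."""
    if training:
        # Training mode: statistics come from the current batch.
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
    else:
        # Inference mode: statistics come from the stored moving averages.
        mean, var = moving_mean, moving_var
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]

batch = [1.0, 2.0, 3.0, 4.0]
# Hypothetical stored statistics; the NaN variance is an assumption for illustration.
stored_mean, stored_var = 0.0, float("nan")

train_out = batch_norm(batch, stored_mean, stored_var, training=True)
infer_out = batch_norm(batch, stored_mean, stored_var, training=False)

print(train_out)  # finite values centred on zero
print(infer_out)  # every entry is nan
```

Dropout is mode-dependent in a similar way: it only drops units when training = True. Which of these effects matters has to be checked against the specific model.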
Thanks for the explanation. Now I understand the logic of feed: one can feed several parameters, each separately. Coming from Python, the logic in Java is really difficult for a beginner to understand. I will write a detailed answer if I solve the problem. Thanks again for your explanation.
The feeding logic is pretty much exactly how it worked in TF 1. We're working on function support that will line up with how tf.function works in TF 2; this is already available for inference using a saved model bundle in TF Java.
Unfortunately, I have not found how to feed the boolean into my model.
Here are the operations in the graph; I cannot find one that might need a boolean feed.
<VarHandleOp 'conv2d_transpose_7/kernel'>
<ReadVariableOp 'conv2d_transpose_7/kernel/Read/ReadVariableOp'>
<VarHandleOp 'conv2d_transpose_7/bias'>
<ReadVariableOp 'conv2d_transpose_7/bias/Read/ReadVariableOp'>
<VarHandleOp 'sequential/conv2d/kernel'>
<ReadVariableOp 'sequential/conv2d/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential/p_re_lu/alpha'>
<ReadVariableOp 'sequential/p_re_lu/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_1/conv2d_1/kernel'>
<ReadVariableOp 'sequential_1/conv2d_1/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_1/batch_normalization/gamma'>
<ReadVariableOp 'sequential_1/batch_normalization/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_1/batch_normalization/beta'>
<ReadVariableOp 'sequential_1/batch_normalization/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_1/batch_normalization/moving_mean'>
<ReadVariableOp 'sequential_1/batch_normalization/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_1/batch_normalization/moving_variance'>
<ReadVariableOp 'sequential_1/batch_normalization/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_1/p_re_lu_1/alpha'>
<ReadVariableOp 'sequential_1/p_re_lu_1/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_2/conv2d_2/kernel'>
<ReadVariableOp 'sequential_2/conv2d_2/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_2/batch_normalization_1/gamma'>
<ReadVariableOp 'sequential_2/batch_normalization_1/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_2/batch_normalization_1/beta'>
<ReadVariableOp 'sequential_2/batch_normalization_1/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_2/batch_normalization_1/moving_mean'>
<ReadVariableOp 'sequential_2/batch_normalization_1/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_2/batch_normalization_1/moving_variance'>
<ReadVariableOp 'sequential_2/batch_normalization_1/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_2/p_re_lu_2/alpha'>
<ReadVariableOp 'sequential_2/p_re_lu_2/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_3/conv2d_3/kernel'>
<ReadVariableOp 'sequential_3/conv2d_3/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_3/batch_normalization_2/gamma'>
<ReadVariableOp 'sequential_3/batch_normalization_2/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_3/batch_normalization_2/beta'>
<ReadVariableOp 'sequential_3/batch_normalization_2/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_3/batch_normalization_2/moving_mean'>
<ReadVariableOp 'sequential_3/batch_normalization_2/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_3/batch_normalization_2/moving_variance'>
<ReadVariableOp 'sequential_3/batch_normalization_2/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_3/p_re_lu_3/alpha'>
<ReadVariableOp 'sequential_3/p_re_lu_3/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_4/conv2d_4/kernel'>
<ReadVariableOp 'sequential_4/conv2d_4/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_4/batch_normalization_3/gamma'>
<ReadVariableOp 'sequential_4/batch_normalization_3/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_4/batch_normalization_3/beta'>
<ReadVariableOp 'sequential_4/batch_normalization_3/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_4/batch_normalization_3/moving_mean'>
<ReadVariableOp 'sequential_4/batch_normalization_3/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_4/batch_normalization_3/moving_variance'>
<ReadVariableOp 'sequential_4/batch_normalization_3/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_4/p_re_lu_4/alpha'>
<ReadVariableOp 'sequential_4/p_re_lu_4/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_5/conv2d_5/kernel'>
<ReadVariableOp 'sequential_5/conv2d_5/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_5/batch_normalization_4/gamma'>
<ReadVariableOp 'sequential_5/batch_normalization_4/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_5/batch_normalization_4/beta'>
<ReadVariableOp 'sequential_5/batch_normalization_4/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_5/batch_normalization_4/moving_mean'>
<ReadVariableOp 'sequential_5/batch_normalization_4/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_5/batch_normalization_4/moving_variance'>
<ReadVariableOp 'sequential_5/batch_normalization_4/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_5/p_re_lu_5/alpha'>
<ReadVariableOp 'sequential_5/p_re_lu_5/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_6/conv2d_6/kernel'>
<ReadVariableOp 'sequential_6/conv2d_6/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_6/batch_normalization_5/gamma'>
<ReadVariableOp 'sequential_6/batch_normalization_5/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_6/batch_normalization_5/beta'>
<ReadVariableOp 'sequential_6/batch_normalization_5/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_6/batch_normalization_5/moving_mean'>
<ReadVariableOp 'sequential_6/batch_normalization_5/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_6/batch_normalization_5/moving_variance'>
<ReadVariableOp 'sequential_6/batch_normalization_5/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_6/p_re_lu_6/alpha'>
<ReadVariableOp 'sequential_6/p_re_lu_6/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_7/conv2d_transpose/kernel'>
<ReadVariableOp 'sequential_7/conv2d_transpose/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_7/batch_normalization_6/gamma'>
<ReadVariableOp 'sequential_7/batch_normalization_6/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_7/batch_normalization_6/beta'>
<ReadVariableOp 'sequential_7/batch_normalization_6/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_7/batch_normalization_6/moving_mean'>
<ReadVariableOp 'sequential_7/batch_normalization_6/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_7/batch_normalization_6/moving_variance'>
<ReadVariableOp 'sequential_7/batch_normalization_6/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_8/conv2d_transpose_1/kernel'>
<ReadVariableOp 'sequential_8/conv2d_transpose_1/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_8/batch_normalization_7/gamma'>
<ReadVariableOp 'sequential_8/batch_normalization_7/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_8/batch_normalization_7/beta'>
<ReadVariableOp 'sequential_8/batch_normalization_7/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_8/batch_normalization_7/moving_mean'>
<ReadVariableOp 'sequential_8/batch_normalization_7/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_8/batch_normalization_7/moving_variance'>
<ReadVariableOp 'sequential_8/batch_normalization_7/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_9/conv2d_transpose_2/kernel'>
<ReadVariableOp 'sequential_9/conv2d_transpose_2/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_9/batch_normalization_8/gamma'>
<ReadVariableOp 'sequential_9/batch_normalization_8/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_9/batch_normalization_8/beta'>
<ReadVariableOp 'sequential_9/batch_normalization_8/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_9/batch_normalization_8/moving_mean'>
<ReadVariableOp 'sequential_9/batch_normalization_8/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_9/batch_normalization_8/moving_variance'>
<ReadVariableOp 'sequential_9/batch_normalization_8/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_10/conv2d_transpose_3/kernel'>
<ReadVariableOp 'sequential_10/conv2d_transpose_3/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_10/batch_normalization_9/gamma'>
<ReadVariableOp 'sequential_10/batch_normalization_9/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_10/batch_normalization_9/beta'>
<ReadVariableOp 'sequential_10/batch_normalization_9/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_10/batch_normalization_9/moving_mean'>
<ReadVariableOp 'sequential_10/batch_normalization_9/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_10/batch_normalization_9/moving_variance'>
<ReadVariableOp 'sequential_10/batch_normalization_9/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_11/conv2d_transpose_4/kernel'>
<ReadVariableOp 'sequential_11/conv2d_transpose_4/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_11/batch_normalization_10/gamma'>
<ReadVariableOp 'sequential_11/batch_normalization_10/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_11/batch_normalization_10/beta'>
<ReadVariableOp 'sequential_11/batch_normalization_10/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_11/batch_normalization_10/moving_mean'>
<ReadVariableOp 'sequential_11/batch_normalization_10/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_11/batch_normalization_10/moving_variance'>
<ReadVariableOp 'sequential_11/batch_normalization_10/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_12/conv2d_transpose_5/kernel'>
<ReadVariableOp 'sequential_12/conv2d_transpose_5/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_12/batch_normalization_11/gamma'>
<ReadVariableOp 'sequential_12/batch_normalization_11/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_12/batch_normalization_11/beta'>
<ReadVariableOp 'sequential_12/batch_normalization_11/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_12/batch_normalization_11/moving_mean'>
<ReadVariableOp 'sequential_12/batch_normalization_11/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_12/batch_normalization_11/moving_variance'>
<ReadVariableOp 'sequential_12/batch_normalization_11/moving_variance/Read/ReadVariableOp'>
<NoOp 'NoOp'>
<Const 'Const'>
<Placeholder 'serving_default_input_1'>
<StatefulPartitionedCall 'StatefulPartitionedCall'>
<Placeholder 'saver_filename'>
<StatefulPartitionedCall 'StatefulPartitionedCall_1'>
<StatefulPartitionedCall 'StatefulPartitionedCall_2'>
What exactly should I search for in the model?
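Feeds are normally supplied to Placeholder ops, so one thing to look for is every op of that type. The sketch below is a plain-Python filter over a text listing; it assumes the listing is in the `<Type 'name'>` form shown above, and `placeholders` is a hypothetical helper name. Applied to the tail of the listing, it finds only the image input and saver_filename, which suggests (a reading of the listing, not something confirmed in this thread) that no separate boolean training input was exported.

```python
import re

# Matches lines such as: <Placeholder 'serving_default_input_1'>
OP_LINE = re.compile(r"<(\w+) '([^']+)'>")

def placeholders(dump):
    """Return the names of all Placeholder ops in a "<Type 'name'>" listing."""
    return [name for op_type, name in OP_LINE.findall(dump)
            if op_type == "Placeholder"]

# Tail of the operation listing, copied from the post above.
dump = """
<NoOp 'NoOp'>
<Const 'Const'>
<Placeholder 'serving_default_input_1'>
<StatefulPartitionedCall 'StatefulPartitionedCall'>
<Placeholder 'saver_filename'>
<StatefulPartitionedCall 'StatefulPartitionedCall_1'>
<StatefulPartitionedCall 'StatefulPartitionedCall_2'>
"""

found = placeholders(dump)
print(found)  # ['serving_default_input_1', 'saver_filename']
```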
I have the same issue with providing training = true in TensorFlow.js. Is there any workaround for this problem?