
How to add training = True in tensorflow-java?

Open · micosacak opened this issue 4 years ago · 8 comments

The models that I have trained in Python work only if I also provide training = True in Python.

For instance, in Python:

import tensorflow as tf
my_model = tf.keras.models.load_model("model")
prdct = my_model(input_image, training = True)

Here, if I do not provide training = True, all the values in prdct.numpy() are NaN.

I have the same problem in tensorflow/java as well. Is there a way to pass the training = True option in tensorflow/java? Here is the Java code:

    TFloat32 ImagePredicted = (TFloat32) sess
            .runner()
            .feed(input_layer, inputTensor)
            .fetch(outputLayer)
            .run()
            .get(0);

micosacak · Apr 09 '21

Depending on how you've loaded your model and how it's specified, you might be able to feed in a boolean (scalar) tensor that specifies whether it's training or not. I've done this before with the TF 1.x bindings to change the behaviour of a dropout layer. However, that's going to depend on the model you've built in Keras, and on how it gets exported from Keras and loaded into TF-Java.
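
Roughly, that approach would look like the following in TF-Java (a sketch against the 0.3.x API; the placeholder name "training" is hypothetical, and whether such a placeholder exists at all depends on how your graph was exported):

    import org.tensorflow.SavedModelBundle;
    import org.tensorflow.types.TBool;
    import org.tensorflow.types.TFloat32;

    try (SavedModelBundle model = SavedModelBundle.load("model", "serve");
         TBool trainingFlag = TBool.scalarOf(true)) {  // scalar boolean tensor
      TFloat32 predicted = (TFloat32) model.session()
          .runner()
          .feed("input_name", inputTensor)   // inputTensor built beforehand, e.g. with TFloat32.tensorOf
          .feed("training", trainingFlag)    // hypothetical boolean placeholder name
          .fetch("output_name")              // your model's output op name
          .run()
          .get(0);
      // read the prediction out of predicted here
    }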

Craigacp · Apr 10 '21

This is the model in Python/TensorFlow (versions 2.2 and 2.4.1); on the Java side I use TF2 (tensorflow-core-platform, version 0.3.1).

The details of the model can be found here.

import tensorflow as tf

# OUTPUT_CHANNELS is defined elsewhere in the original script; assumed here:
OUTPUT_CHANNELS = 1

def downsample(filters, size, apply_batchnorm = True):
	#initializer = tf.random_normal_initializer(0., 0.02)
	initializer = "he_normal"
	result = tf.keras.Sequential()
	result.add(
	  tf.keras.layers.Conv2D(filters, size, strides = 2, padding='same', kernel_initializer=initializer, use_bias=False))
	if apply_batchnorm:
		result.add(tf.keras.layers.BatchNormalization())
	result.add(tf.keras.layers.LeakyReLU())
	return result

def upsample(filters, size, apply_dropout=False):
	#initializer = tf.random_normal_initializer(0., 0.02)
	initializer = "he_normal"
	result = tf.keras.Sequential()
	result.add(
	tf.keras.layers.Conv2DTranspose(filters, size, strides = 2, padding='same', kernel_initializer=initializer, use_bias=False))
	result.add(tf.keras.layers.BatchNormalization())
	if apply_dropout:
		result.add(tf.keras.layers.Dropout(0.5))
	result.add(tf.keras.layers.LeakyReLU())
	return result

def Generator(input_shape = (128,128,1)):
	inputs = tf.keras.layers.Input(input_shape)
	down_stack = [
	downsample(64, 4, apply_batchnorm = False),
	downsample(128, 4), 
	downsample(256, 4),
	downsample(256, 4),
	downsample(512, 4),
	downsample(512, 4),
	downsample(512, 4),
	]
	up_stack = [
	upsample(512, 4, apply_dropout = True), 
	upsample(512, 4, apply_dropout = True), 
	upsample(512, 4, apply_dropout = True), 
	upsample(256, 4),
	upsample(256, 4),
	upsample(128, 4), 
	upsample(64, 4),
	]
	initializer = "he_normal"
	last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4, strides = 2, padding='same', kernel_initializer=initializer, activation='tanh') 
	x = inputs
	skips = []
	for down in down_stack:
		x = down(x)
		skips.append(x)
	skips = reversed(skips[:-1])
	for up, skip in zip(up_stack, skips):
		x = up(x)
		x = tf.keras.layers.Concatenate()([x, skip])
	x = last(x)
	return tf.keras.Model(inputs = inputs, outputs = x)

generator = Generator(input_shape = (128, 128, 1))

After training, I simply save the model as:
generator.save("model")

Then I load it in Python as below and set training = True:

my_model = tf.keras.models.load_model("model")
prdct = my_model(input_image, training = True)  # input_image is a (1,128,128,1) numpy array

Now in Java (tensorflow-core-platform, version 0.3.1) I do the following.

    SavedModelBundle theModel = SavedModelBundle.load("model", "serve");
    Session sess = theModel.session();

    // convert the BufferedImage to a float array and then to a tensor
    float[][][][] floatImage = ConvertTtoFloatImage(image, imgHeight, imgWidth); // imgHeight and imgWidth are both 128
    TFloat32 inputTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(floatImage));

    // perform the prediction
    TFloat32 ImagePredicted = (TFloat32) sess
            .runner()
            .feed("serving_default_input_1", inputTensor)
            .fetch("StatefulPartitionedCall")
            .run()
            .get(0);

Now all values in ImagePredicted are -1.

micosacak · Apr 10 '21

I am not sure if this could be a solution.

public void feed(String inputName, boolean[] src, long... dims) {
  byte[] b = new byte[src.length];
  for (int i = 0; i < src.length; i++) {
    b[i] = src[i] ? (byte) 1 : (byte) 0;
  }
  addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
}

As Tensor.create raises the error "The method create(Class&lt;Boolean&gt;, long[], ByteBuffer) is undefined for the type Tensor", I could not test it.

micosacak · Apr 10 '21

You should use TBool.class rather than Boolean.class. However, I'm not sure that will achieve what you want. The effect of the training=true flag in Keras depends on the layers used and on how the model has been constructed (see the call definition here: https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/python/keras/engine/training.py#L428). If that eventually unfolds into a graph which expects a boolean tensor with a specific name, then you can feed it (and it tended to in TF 1.x). However, Keras is more tightly integrated into TF 2.x and many of the layers have more complex behaviour, so you'll need to trace through what your Python code does and figure out what the effect of that flag is for your specific model.
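
For reference, the helper above could be reduced to something like this against the TF-Java 0.3.x API, where TBool builds the tensor directly and no ByteBuffer packing is needed (a sketch only; the op name is still whatever your graph actually contains):

    import org.tensorflow.Session;
    import org.tensorflow.types.TBool;

    // Sketch: feed a scalar boolean flag by operation name (TF-Java 0.3.x).
    public static void feedFlag(Session.Runner runner, String inputName, boolean value) {
      TBool flag = TBool.scalarOf(value);  // replaces the manual byte[]/ByteBuffer packing
      runner.feed(inputName, flag);
    }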

Craigacp · Apr 11 '21

Thanks for the explanation; now I understand the logic of feed, and that one can feed several parameters separately. Coming from Python to Java, it is really difficult for a beginner to understand the logic in Java. I will write a detailed answer if I solve the problem. Thanks again for your explanation.

micosacak · Apr 12 '21

The feeding logic is pretty much exactly how it worked in TF 1.x. We're working on function support which will line up with how tf.function works in TF 2; this is already available for inference using a SavedModelBundle in TF-Java.
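
For inference, that path looks roughly like this in TF-Java 0.3.x ("serving_default" is the usual Keras signature name here, but verify it against your own SavedModel):

    import org.tensorflow.ConcreteFunction;
    import org.tensorflow.SavedModelBundle;
    import org.tensorflow.Tensor;

    try (SavedModelBundle model = SavedModelBundle.load("model", "serve")) {
      ConcreteFunction servingFn = model.function("serving_default");
      // single-input, single-output signatures can be called with one tensor
      try (Tensor result = servingFn.call(inputTensor)) {
        // read the prediction out of result here
      }
    }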

Craigacp · Apr 12 '21

Unfortunately, I have not found out how to feed the boolean into my model.

Here are the operations in the graph; I cannot find one that might need a boolean feed.

<VarHandleOp 'conv2d_transpose_7/kernel'>
<ReadVariableOp 'conv2d_transpose_7/kernel/Read/ReadVariableOp'>
<VarHandleOp 'conv2d_transpose_7/bias'>
<ReadVariableOp 'conv2d_transpose_7/bias/Read/ReadVariableOp'>
<VarHandleOp 'sequential/conv2d/kernel'>
<ReadVariableOp 'sequential/conv2d/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential/p_re_lu/alpha'>
<ReadVariableOp 'sequential/p_re_lu/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_1/conv2d_1/kernel'>
<ReadVariableOp 'sequential_1/conv2d_1/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_1/batch_normalization/gamma'>
<ReadVariableOp 'sequential_1/batch_normalization/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_1/batch_normalization/beta'>
<ReadVariableOp 'sequential_1/batch_normalization/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_1/batch_normalization/moving_mean'>
<ReadVariableOp 'sequential_1/batch_normalization/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_1/batch_normalization/moving_variance'>
<ReadVariableOp 'sequential_1/batch_normalization/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_1/p_re_lu_1/alpha'>
<ReadVariableOp 'sequential_1/p_re_lu_1/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_2/conv2d_2/kernel'>
<ReadVariableOp 'sequential_2/conv2d_2/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_2/batch_normalization_1/gamma'>
<ReadVariableOp 'sequential_2/batch_normalization_1/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_2/batch_normalization_1/beta'>
<ReadVariableOp 'sequential_2/batch_normalization_1/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_2/batch_normalization_1/moving_mean'>
<ReadVariableOp 'sequential_2/batch_normalization_1/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_2/batch_normalization_1/moving_variance'>
<ReadVariableOp 'sequential_2/batch_normalization_1/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_2/p_re_lu_2/alpha'>
<ReadVariableOp 'sequential_2/p_re_lu_2/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_3/conv2d_3/kernel'>
<ReadVariableOp 'sequential_3/conv2d_3/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_3/batch_normalization_2/gamma'>
<ReadVariableOp 'sequential_3/batch_normalization_2/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_3/batch_normalization_2/beta'>
<ReadVariableOp 'sequential_3/batch_normalization_2/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_3/batch_normalization_2/moving_mean'>
<ReadVariableOp 'sequential_3/batch_normalization_2/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_3/batch_normalization_2/moving_variance'>
<ReadVariableOp 'sequential_3/batch_normalization_2/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_3/p_re_lu_3/alpha'>
<ReadVariableOp 'sequential_3/p_re_lu_3/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_4/conv2d_4/kernel'>
<ReadVariableOp 'sequential_4/conv2d_4/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_4/batch_normalization_3/gamma'>
<ReadVariableOp 'sequential_4/batch_normalization_3/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_4/batch_normalization_3/beta'>
<ReadVariableOp 'sequential_4/batch_normalization_3/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_4/batch_normalization_3/moving_mean'>
<ReadVariableOp 'sequential_4/batch_normalization_3/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_4/batch_normalization_3/moving_variance'>
<ReadVariableOp 'sequential_4/batch_normalization_3/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_4/p_re_lu_4/alpha'>
<ReadVariableOp 'sequential_4/p_re_lu_4/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_5/conv2d_5/kernel'>
<ReadVariableOp 'sequential_5/conv2d_5/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_5/batch_normalization_4/gamma'>
<ReadVariableOp 'sequential_5/batch_normalization_4/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_5/batch_normalization_4/beta'>
<ReadVariableOp 'sequential_5/batch_normalization_4/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_5/batch_normalization_4/moving_mean'>
<ReadVariableOp 'sequential_5/batch_normalization_4/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_5/batch_normalization_4/moving_variance'>
<ReadVariableOp 'sequential_5/batch_normalization_4/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_5/p_re_lu_5/alpha'>
<ReadVariableOp 'sequential_5/p_re_lu_5/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_6/conv2d_6/kernel'>
<ReadVariableOp 'sequential_6/conv2d_6/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_6/batch_normalization_5/gamma'>
<ReadVariableOp 'sequential_6/batch_normalization_5/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_6/batch_normalization_5/beta'>
<ReadVariableOp 'sequential_6/batch_normalization_5/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_6/batch_normalization_5/moving_mean'>
<ReadVariableOp 'sequential_6/batch_normalization_5/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_6/batch_normalization_5/moving_variance'>
<ReadVariableOp 'sequential_6/batch_normalization_5/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_6/p_re_lu_6/alpha'>
<ReadVariableOp 'sequential_6/p_re_lu_6/alpha/Read/ReadVariableOp'>
<VarHandleOp 'sequential_7/conv2d_transpose/kernel'>
<ReadVariableOp 'sequential_7/conv2d_transpose/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_7/batch_normalization_6/gamma'>
<ReadVariableOp 'sequential_7/batch_normalization_6/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_7/batch_normalization_6/beta'>
<ReadVariableOp 'sequential_7/batch_normalization_6/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_7/batch_normalization_6/moving_mean'>
<ReadVariableOp 'sequential_7/batch_normalization_6/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_7/batch_normalization_6/moving_variance'>
<ReadVariableOp 'sequential_7/batch_normalization_6/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_8/conv2d_transpose_1/kernel'>
<ReadVariableOp 'sequential_8/conv2d_transpose_1/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_8/batch_normalization_7/gamma'>
<ReadVariableOp 'sequential_8/batch_normalization_7/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_8/batch_normalization_7/beta'>
<ReadVariableOp 'sequential_8/batch_normalization_7/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_8/batch_normalization_7/moving_mean'>
<ReadVariableOp 'sequential_8/batch_normalization_7/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_8/batch_normalization_7/moving_variance'>
<ReadVariableOp 'sequential_8/batch_normalization_7/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_9/conv2d_transpose_2/kernel'>
<ReadVariableOp 'sequential_9/conv2d_transpose_2/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_9/batch_normalization_8/gamma'>
<ReadVariableOp 'sequential_9/batch_normalization_8/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_9/batch_normalization_8/beta'>
<ReadVariableOp 'sequential_9/batch_normalization_8/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_9/batch_normalization_8/moving_mean'>
<ReadVariableOp 'sequential_9/batch_normalization_8/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_9/batch_normalization_8/moving_variance'>
<ReadVariableOp 'sequential_9/batch_normalization_8/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_10/conv2d_transpose_3/kernel'>
<ReadVariableOp 'sequential_10/conv2d_transpose_3/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_10/batch_normalization_9/gamma'>
<ReadVariableOp 'sequential_10/batch_normalization_9/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_10/batch_normalization_9/beta'>
<ReadVariableOp 'sequential_10/batch_normalization_9/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_10/batch_normalization_9/moving_mean'>
<ReadVariableOp 'sequential_10/batch_normalization_9/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_10/batch_normalization_9/moving_variance'>
<ReadVariableOp 'sequential_10/batch_normalization_9/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_11/conv2d_transpose_4/kernel'>
<ReadVariableOp 'sequential_11/conv2d_transpose_4/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_11/batch_normalization_10/gamma'>
<ReadVariableOp 'sequential_11/batch_normalization_10/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_11/batch_normalization_10/beta'>
<ReadVariableOp 'sequential_11/batch_normalization_10/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_11/batch_normalization_10/moving_mean'>
<ReadVariableOp 'sequential_11/batch_normalization_10/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_11/batch_normalization_10/moving_variance'>
<ReadVariableOp 'sequential_11/batch_normalization_10/moving_variance/Read/ReadVariableOp'>
<VarHandleOp 'sequential_12/conv2d_transpose_5/kernel'>
<ReadVariableOp 'sequential_12/conv2d_transpose_5/kernel/Read/ReadVariableOp'>
<VarHandleOp 'sequential_12/batch_normalization_11/gamma'>
<ReadVariableOp 'sequential_12/batch_normalization_11/gamma/Read/ReadVariableOp'>
<VarHandleOp 'sequential_12/batch_normalization_11/beta'>
<ReadVariableOp 'sequential_12/batch_normalization_11/beta/Read/ReadVariableOp'>
<VarHandleOp 'sequential_12/batch_normalization_11/moving_mean'>
<ReadVariableOp 'sequential_12/batch_normalization_11/moving_mean/Read/ReadVariableOp'>
<VarHandleOp 'sequential_12/batch_normalization_11/moving_variance'>
<ReadVariableOp 'sequential_12/batch_normalization_11/moving_variance/Read/ReadVariableOp'>
<NoOp 'NoOp'>
<Const 'Const'>
<Placeholder 'serving_default_input_1'>
<StatefulPartitionedCall 'StatefulPartitionedCall'>
<Placeholder 'saver_filename'>
<StatefulPartitionedCall 'StatefulPartitionedCall_1'>
<StatefulPartitionedCall 'StatefulPartitionedCall_2'>

What exactly should I be searching for in the model?
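
(For reference, the listing can be narrowed down from TF-Java by filtering for Placeholder ops and printing their dtypes, roughly as below; this is a sketch, with method names per the 0.3.x API. In the listing above the only Placeholders are serving_default_input_1 and saver_filename, so there is apparently no boolean training flag to feed.)

    // Sketch: print every Placeholder in the loaded graph with its dtype.
    theModel.graph().operations().forEachRemaining(op -> {
      if ("Placeholder".equals(op.type())) {
        System.out.println(op.name() + " -> " + op.output(0).dataType());
      }
    });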

micosacak · Apr 20 '21

I have the same issue providing training = true in TensorFlow.js. I wonder if there is any workaround for this problem.

Anjalivenugopal99 · Jun 14 '21