generative_inpainting

How to run multiple images with multiple masks respectively

Open TrinhQuocNguyen opened this issue 6 years ago • 20 comments

Thank you for your contribution,

At the moment, I have checked the test file, and it can only run on one image/mask pair. I tried putting the code in a for loop, but I got an error at this line:

    output = model.build_server_graph(input_image)

Traceback:

    File "/home/ubuntu/trinh/generative_inpainting/inpaint_model.py", line 307, in build_server_graph
      config=None)
    File "/home/ubuntu/trinh/generative_inpainting/inpaint_model.py", line 50, in build_inpaint_net
      x = gen_conv(x, cnum, 5, 1, name='conv1')
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 181, in func_with_args
      return func(*args, **current_args)
    File "/home/ubuntu/trinh/generative_inpainting/inpaint_ops.py", line 45, in gen_conv
      activation=activation, padding=padding, name=name)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/convolutional.py", line 608, in conv2d
      return layer.apply(inputs)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/base.py", line 671, in apply
      return self.call(inputs, *args, **kwargs)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/base.py", line 559, in call
      self.build(input_shapes[0])
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/convolutional.py", line 143, in build
      dtype=self.dtype)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/base.py", line 458, in add_variable
      trainable=trainable and self.trainable)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 1203, in get_variable
      constraint=constraint)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 1092, in get_variable
      constraint=constraint)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 425, in get_variable
      constraint=constraint)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 394, in _true_getter
      use_resource=use_resource, constraint=constraint)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 742, in _get_single_variable
      name, "".join(traceback.format_list(tb))))
    ValueError: Variable inpaint_net/conv1/kernel already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

    File "/home/ubuntu/trinh/generative_inpainting/inpaint_ops.py", line 45, in gen_conv
      activation=activation, padding=padding, name=name)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 181, in func_with_args
      return func(*args, **current_args)
    File "/home/ubuntu/trinh/generative_inpainting/inpaint_model.py", line 50, in build_inpaint_net
      x = gen_conv(x, cnum, 5, 1, name='conv1')

Do you have any idea what I have done wrong? Thank you.

TrinhQuocNguyen • Apr 04 '18 07:04

Oh, thank you, I have found the answer: just pass reuse=tf.AUTO_REUSE:

    output = model.build_server_graph(input_image, reuse=tf.AUTO_REUSE)

TensorFlow will then reuse the existing variables instead of trying to create them a second time.

TrinhQuocNguyen • Apr 04 '18 08:04
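For readers wondering why `reuse=tf.AUTO_REUSE` makes the error go away, here is a toy, pure-Python model of TF1's variable reuse rules: `reuse=False` may only create, `reuse=True` may only reuse, and `AUTO_REUSE` gets or creates. `ToyVariableStore`, `build_toy_net` and the `AUTO_REUSE` string are hypothetical stand-ins, not TensorFlow code. Keep in mind that AUTO_REUSE only shares the *variables*; every `build_server_graph` call still adds another copy of the ops to the graph.

```python
# Toy model of TF1 variable-scope reuse semantics (NOT TensorFlow's code).
# It only mimics the three reuse modes to show why a second graph build
# with reuse=False fails with "already exists, disallowed".

AUTO_REUSE = "auto_reuse"  # hypothetical stand-in for tf.AUTO_REUSE


class ToyVariableStore:
    def __init__(self):
        self._vars = {}
        self.creations = 0  # how many variables were actually created

    def get_variable(self, name, reuse=False):
        exists = name in self._vars
        if exists and reuse is False:
            # Same situation as the error in this issue.
            raise ValueError("Variable %s already exists, disallowed." % name)
        if not exists and reuse is True:
            # reuse=True refuses to create new variables.
            raise ValueError("Variable %s does not exist, disallowed." % name)
        if not exists:
            self._vars[name] = [0.0]  # placeholder "weights"
            self.creations += 1
        return self._vars[name]


def build_toy_net(store, reuse=False):
    # Uses the first layer name from the traceback above.
    return store.get_variable("inpaint_net/conv1/kernel", reuse=reuse)


store = ToyVariableStore()
first = build_toy_net(store)                     # first build creates it
second = build_toy_net(store, reuse=AUTO_REUSE)  # later builds reuse it

try:
    build_toy_net(store)  # default reuse=False: fails like the original loop
except ValueError as err:
    print(err)
```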

It would be even more efficient to build the graph ONCE with a placeholder and then feed your images through sess.run. A related issue is #8.

JiahuiYu • Apr 04 '18 14:04

Hello JiahuiYu, thank you for your quick response. Did you mean sess.run? I'm reading your source code to understand what you have done.

TrinhQuocNguyen • Apr 05 '18 08:04

Sorry, typo.

JiahuiYu • Apr 05 '18 20:04

Hello JiahuiYu, thank you for your response. I'm building the graph. In the inpaint.yml file, under the # loss legacy line, I found the VGG_MODEL_FILE setting you have configured. I have read your paper, and it does not mention transfer learning. So I wonder: can we use a VGG16 network for transfer learning? Thank you for your help.

TrinhQuocNguyen • Apr 10 '18 05:04

"We have not found perceptual loss (reconstruction loss on VGG features), style loss (squared Frobenius norm of Gram matrix computed on the VGG features) [21] and total variation (TV) loss bring noticeable improvements for image inpainting in our framework, thus are not used."

You will need to implement the VGG16 perceptual loss yourself.

JiahuiYu • Apr 10 '18 05:04
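If you do want to add the losses named in the quoted passage, the arithmetic itself is small. The sketch below works on plain Python lists holding C channels of N flattened spatial positions each, so it runs without TensorFlow. The L1 distance for the perceptual loss, the 1/(C*N) Gram normalisation, and the function names are illustrative assumptions, not code from this repository; extracting the features from VGG16 is left out.

```python
# Sketch of the two VGG-feature losses from the quoted passage, on C x N
# feature maps given as lists of lists. Normalisation and the L1 distance
# are assumptions; real features would come from VGG16 layers.


def perceptual_loss(feat_a, feat_b):
    """Reconstruction loss on features: mean absolute (L1) difference."""
    total, count = 0.0, 0
    for row_a, row_b in zip(feat_a, feat_b):
        for a, b in zip(row_a, row_b):
            total += abs(a - b)
            count += 1
    return total / count


def gram_matrix(feat):
    """C x C Gram matrix: G[i][j] = sum_k F[i][k] * F[j][k] / (C * N)."""
    c, n = len(feat), len(feat[0])
    return [[sum(fi * fj for fi, fj in zip(feat[i], feat[j])) / (c * n)
             for j in range(c)]
            for i in range(c)]


def style_loss(feat_a, feat_b):
    """Squared Frobenius norm of the difference of the two Gram matrices."""
    ga, gb = gram_matrix(feat_a), gram_matrix(feat_b)
    return sum((x - y) ** 2
               for row_a, row_b in zip(ga, gb)
               for x, y in zip(row_a, row_b))


# Two tiny feature maps: 2 channels, 3 positions each.
fa = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]
fb = [[1.0, 1.0, 2.0], [0.0, 1.0, 0.0]]
print(perceptual_loss(fa, fb))  # two entries differ by 1, out of 6
print(style_loss(fa, fb))
```

The perceptual term compares features position by position, while the style term compares only channel correlations, which is why it ignores where in the image a texture appears.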

Thank you for your fast response. I used your pretrained model for transfer learning, and it saved me a lot of time on a new training set. I am reading your paper again; I think it's a great paper.

TrinhQuocNguyen • Apr 11 '18 06:04

Hello JiahuiYu, thank you for your awesome code. I have tried to modify it so that the graph is built only once, but unfortunately I could not get it to work.
I see that you use the build_server_graph function, but I don't understand it well. Could you please share some code that builds the graph once and then feeds the images into it one by one? Thank you in advance.

TrinhQuocNguyen • Apr 17 '18 01:04

Here is my code at the moment, using a for loop:

    # prepare folder path
    input_folder = args.test_dir + "/input"
    mask_folder = args.test_dir + "/mask"
    output_folder = args.test_dir + "/output_" + args.checkpoint_dir.split("/")[1] + "_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")

    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # start sess configuration
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    


    dir_files = os.listdir(input_folder)
    dir_files.sort()

    for file_inter in dir_files:
        sess = tf.Session(config=sess_config)
        
        base_file_name = os.path.basename(file_inter)

        image = cv2.imread(input_folder + "/" + base_file_name)
        mask = cv2.imread(mask_folder + "/" + base_file_name)

        assert image.shape == mask.shape

        h, w, _ = image.shape
        grid = 1
        image = image[:h//grid*grid, :w//grid*grid, :]
        mask = mask[:h//grid*grid, :w//grid*grid, :]
        print('Shape of image: {}'.format(image.shape))

        image = np.expand_dims(image, 0)
        mask = np.expand_dims(mask, 0)
        input_image = np.concatenate([image, mask], axis=2)

        input_image = tf.constant(input_image, dtype=tf.float32)
        output = model.build_server_graph(input_image, reuse=tf.AUTO_REUSE)
        output = (output + 1.) * 127.5
        output = tf.reverse(output, [-1])
        output = tf.saturate_cast(output, tf.uint8)
        # load pretrained model
        vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        assign_ops = []
        for var in vars_list:
            vname = var.name
            from_name = vname
            var_value = tf.contrib.framework.load_variable(args.checkpoint_dir, from_name)
            assign_ops.append(tf.assign(var, var_value))
        sess.run(assign_ops)
        print('Model loaded.')
        result = sess.run(output)

        # write to output folder
        cv2.imwrite(output_folder + "/" + base_file_name, result[0][:, :, ::-1])
        sess.close()

TrinhQuocNguyen • Apr 17 '18 01:04
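The three lines after build_server_graph in the code above convert the generator output back to 8-bit pixels: (output + 1.) * 127.5 maps [-1, 1] to [0, 255], tf.reverse(output, [-1]) flips the channel (last) axis, and tf.saturate_cast clamps into the uint8 range instead of wrapping around. Below is a per-pixel, pure-Python sketch of that arithmetic. The helper names are hypothetical, and the clamp-then-truncate rounding is an assumption about the float-to-uint8 cast; check the tf.saturate_cast documentation if exact rounding matters.

```python
# Per-pixel sketch of the post-processing in the test snippet:
#   (x + 1) * 127.5  -> [-1, 1] becomes [0, 255]
#   tf.reverse(..., [-1]) -> reverse the channel order
#   tf.saturate_cast(..., tf.uint8) -> clamp, then cast to an integer


def to_uint8(value):
    """Map one generator value in [-1, 1] to an integer in [0, 255]."""
    scaled = (value + 1.0) * 127.5
    clamped = min(max(scaled, 0.0), 255.0)  # saturate instead of wrapping
    return int(clamped)  # int() truncates toward zero (assumed cast rule)


def postprocess_pixel(channels):
    """Convert one pixel and reverse its channel order like tf.reverse."""
    return [to_uint8(c) for c in channels][::-1]


print(postprocess_pixel([-1.0, 0.0, 1.0]))  # -> [255, 127, 0]
```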

Actually, your usage is not correct. In TensorFlow-based code, the graph-building function should be called only once, unless you deliberately want to reuse the graph. I've modified the code for your case; please use the following:

    import argparse
    import time

    import cv2
    import numpy as np
    import tensorflow as tf

    from inpaint_model import InpaintCAModel

    # Command-line arguments used below.
    parser = argparse.ArgumentParser()
    parser.add_argument('--flist', type=str, required=True,
                        help='text file with "<image> <mask> <output>" per line')
    parser.add_argument('--image_height', type=int, required=True)
    parser.add_argument('--image_width', type=int, required=True)
    parser.add_argument('--checkpoint_dir', type=str, required=True)
    args = parser.parse_args()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    # Build the graph ONCE, with a placeholder holding the image and the
    # mask side by side (hence image_width*2).
    model = InpaintCAModel()
    input_image_ph = tf.placeholder(
        tf.float32, shape=(1, args.image_height, args.image_width*2, 3))
    output = model.build_server_graph(input_image_ph)
    output = (output + 1.) * 127.5
    output = tf.reverse(output, [-1])
    output = tf.saturate_cast(output, tf.uint8)

    # Load the pretrained weights ONCE.
    vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    assign_ops = []
    for var in vars_list:
        var_value = tf.contrib.framework.load_variable(
            args.checkpoint_dir, var.name)
        assign_ops.append(tf.assign(var, var_value))
    sess.run(assign_ops)
    print('Model loaded.')

    with open(args.flist, 'r') as f:
        lines = f.read().splitlines()
    t = time.time()
    for line in lines:
        image, mask, out = line.split()

        image = cv2.imread(image)
        mask = cv2.imread(mask)
        image = cv2.resize(image, (args.image_width, args.image_height))
        mask = cv2.resize(mask, (args.image_width, args.image_height))

        assert image.shape == mask.shape

        # The crop is a no-op only when image_height and image_width are
        # multiples of grid; otherwise the array no longer matches the
        # fixed placeholder shape.
        h, w, _ = image.shape
        grid = 4
        image = image[:h//grid*grid, :w//grid*grid, :]
        mask = mask[:h//grid*grid, :w//grid*grid, :]
        print('Shape of image: {}'.format(image.shape))

        image = np.expand_dims(image, 0)
        mask = np.expand_dims(mask, 0)
        input_image = np.concatenate([image, mask], axis=2)

        # Run inference: feed this image/mask pair into the placeholder.
        result = sess.run(output, feed_dict={input_image_ph: input_image})
        print('Processed: {}'.format(out))
        cv2.imwrite(out, result[0][:, :, ::-1])

    print('Time total: {}'.format(time.time() - t))

JiahuiYu • Apr 17 '18 02:04
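The loop above reads args.flist and unpacks each line with image, mask, out = line.split(), so every line needs three whitespace-separated paths. A minimal sketch for generating such a file, assuming the input/ and mask/ folder layout with matching file names from the earlier snippet; write_flist and the output-folder choice are hypothetical, and paths containing spaces would break line.split().

```python
# Hypothetical helper: write "<image> <mask> <output>" lines for every
# image in <test_dir>/input that has a same-named file in <test_dir>/mask.
import os
import tempfile


def write_flist(test_dir, output_dir, flist_path):
    input_dir = os.path.join(test_dir, "input")
    mask_dir = os.path.join(test_dir, "mask")
    # cv2.imwrite does not create folders, so make the output folder here.
    os.makedirs(output_dir, exist_ok=True)
    lines = []
    for name in sorted(os.listdir(input_dir)):
        mask_path = os.path.join(mask_dir, name)
        if not os.path.exists(mask_path):
            continue  # skip images that have no mask with the same name
        lines.append(" ".join([os.path.join(input_dir, name),
                               mask_path,
                               os.path.join(output_dir, name)]))
    with open(flist_path, "w") as f:
        f.write("".join(line + "\n" for line in lines))
    return lines


# Small demo on a temporary folder: b.png has no mask and is skipped.
demo_root = tempfile.mkdtemp()
for sub in ("input", "mask"):
    os.makedirs(os.path.join(demo_root, sub))
for name in ("a.png", "b.png"):
    open(os.path.join(demo_root, "input", name), "wb").close()
open(os.path.join(demo_root, "mask", "a.png"), "wb").close()
demo_flist = os.path.join(demo_root, "flist.txt")
demo_lines = write_flist(demo_root, os.path.join(demo_root, "output"),
                         demo_flist)
print(demo_lines)
```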

Hi JiahuiYu, thank you very much for your code and your contribution. I am so excited to try it out. Thank you again 😄 😄 😄 😄 😄 😄 😄 😄 😄 😄

TrinhQuocNguyen • Apr 17 '18 02:04

Hi JiahuiYu, wow, it worked. Thank you very much; you have saved me tons of time. 😍 😍 😍

TrinhQuocNguyen • Apr 17 '18 06:04

No problem. :)

JiahuiYu • Apr 17 '18 16:04
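One detail in the batch code earlier in this thread is easy to trip over: the placeholder has the fixed shape (1, image_height, image_width*2, 3), because the image and mask are concatenated side by side along the width axis. After resizing, the grid crop must therefore leave the size unchanged, or sess.run rejects the fed array. The hypothetical helpers below reproduce that shape arithmetic in plain Python.

```python
# Shape bookkeeping for the batch loop: resize -> crop to a multiple of
# grid -> add a batch axis -> concatenate image and mask along width.


def cropped_size(size, grid):
    """Largest multiple of grid that is <= size, i.e. size // grid * grid."""
    return size // grid * grid


def network_input_shape(height, width, grid):
    """Shape of the array fed to the placeholder after crop and concat."""
    h = cropped_size(height, grid)
    w = cropped_size(width, grid)
    return (1, h, 2 * w, 3)


def matches_placeholder(height, width, grid):
    """True if the fed array fits shape=(1, height, width * 2, 3)."""
    return network_input_shape(height, width, grid) == (1, height, width * 2, 3)


print(network_input_shape(256, 256, 4))  # (1, 256, 512, 3)
print(matches_placeholder(256, 256, 4))  # True
print(matches_placeholder(250, 256, 4))  # False: 250 is cropped to 248
```

In practice this means choosing --image_height and --image_width as multiples of grid.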

These codes should be added to the master branch 😍 😍 😍

Bingmang • May 22 '19 05:05

> These codes should be added to the master branch 😍 😍 😍

@Bingmang Has this code been added to the for loop in test.py? Thank you.

TianLuluC • Aug 01 '19 08:08

I have left this thread open so others can use it as a reference.

JiahuiYu • Aug 03 '19 19:08

@TrinhQuocNguyen Thank you very much for the discussion about training a new model! Could you give me more detailed instructions on using the pretrained model for transfer learning? Thanks a lot!

zylxadz • Apr 20 '20 08:04

great!

minushuang • May 21 '20 09:05

                                                                                                                                                                                                                                              
        # load pretrained model                                                                                                                                                                                                               
        result = sess.run(output, feed_dict={input_image_ph: input_image})                                                                                                                                                                    
        print('Processed: {}'.format(out))                                                                                                                                                                                                    
        cv2.imwrite(out, result[0][:, :, ::-1])                                                                                                                                                                                               
                                                                                                                                                                                                                                              
    print('Time total: {}'.format(time.time() - t)) 
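The preprocessing in the loop above crops both arrays to a multiple of `grid`, adds a batch axis, and then places image and mask side by side: once the batch axis exists, `axis=2` is the width axis, so the array fed to `input_image_ph` has shape `(1, H, 2W, 3)`. A minimal standalone sketch of just that step, using NumPy only; the function name `prepare_input` is hypothetical and not part of the repository:

```python
import numpy as np


def prepare_input(image, mask, grid=4):
    """Crop image and mask to a multiple of `grid` and stack them side by side.

    Hypothetical helper mirroring the loop body above; expects two
    H x W x 3 arrays of the same shape (e.g. as returned by cv2.imread).
    """
    assert image.shape == mask.shape, 'image and mask must have the same shape'
    h, w, _ = image.shape
    # Drop trailing rows/columns so height and width are divisible by `grid`.
    image = image[:h // grid * grid, :w // grid * grid, :]
    mask = mask[:h // grid * grid, :w // grid * grid, :]
    # Add a batch axis: (H, W, 3) -> (1, H, W, 3).
    image = np.expand_dims(image, 0)
    mask = np.expand_dims(mask, 0)
    # axis=2 is now the width axis, so the result is (1, H, 2W, 3):
    # the image fills the left half, the mask the right half.
    return np.concatenate([image, mask], axis=2)


# Example with a 130 x 258 input: cropped to 128 x 256, then doubled in width.
image = np.zeros((130, 258, 3), dtype=np.uint8)
mask = np.full((130, 258, 3), 255, dtype=np.uint8)
batch = prepare_input(image, mask)
print(batch.shape)  # (1, 128, 512, 3)
```

Because every pair is first resized to `args.image_width` x `args.image_height`, each array produced this way has the same shape, so all pairs can be fed through the same placeholder and `sess.run` call without rebuilding the graph.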

The `build_server_graph` call should be:

    output = model.build_server_graph(FLAGS, input_image_ph)

JeremyCJM, Oct 19 '20 14:10

Hey, I've been trying for days to customize some parts. Can you explain how to access the model and run `model.summary()`?

arnavmehta7, Oct 11 '22 05:10