NiftyNet

Using a pre-trained model in a way like TensorFlow Hub

Open · johnnychhsu opened this issue 5 years ago · 11 comments

I am wondering whether we can use a NiftyNet pre-trained model the way TensorFlow Hub modules are used, as in this example from TensorFlow Hub:

import tensorflow as tf
import tensorflow_hub as hub

with tf.Graph().as_default():
  module_url = "https://tfhub.dev/google/nnlm-en-dim128-with-normalization/1"
  embed = hub.Module(module_url)
  embeddings = embed(["A long sentence.", "single-word",
                      "http://example.com"])

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())

    print(sess.run(embeddings))

I want to do transfer learning with the pre-trained model on my dataset. I think I could modify the config file to achieve this, but I am used to the TensorFlow Hub way, hence this question. So my questions are:

  1. Is it possible to do it the above way now? If yes, maybe I can add an example and open a pull request after I finish my work.
  2. If not, could it be supported in the future?

Thank you!

johnnychhsu, Oct 22 '18

Hi @johnnychhsu, an example of transfer learning is presented here: https://github.com/NifTK/NiftyNet/pull/140. It's still a work in progress; mainly we need a generic user interface (via the config file) so that the user can specify which variables are initialised from checkpoint files and which from random initialisation. If you are also interested in this direction, we could collaborate on it.
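The core of such an interface is a rule that splits the network's variables into "restore from checkpoint" and "initialise randomly". A minimal sketch, assuming variables are identified by their scope names and the user supplies a regular expression for the ones to restore (the function name and the example variable names are hypothetical, not NiftyNet's API):

```python
import re
from typing import Iterable, List, Tuple


def split_variables(var_names: Iterable[str],
                    restore_pattern: str) -> Tuple[List[str], List[str]]:
    """Partition variable names into (restore_from_checkpoint, random_init).

    A name is restored when `restore_pattern` matches anywhere in it
    (re.search semantics); every other variable keeps its random initialiser.
    """
    regex = re.compile(restore_pattern)
    to_restore, to_init = [], []
    for name in var_names:
        (to_restore if regex.search(name) else to_init).append(name)
    return to_restore, to_init


if __name__ == "__main__":
    # Hypothetical variable names for a network whose final classifier is replaced.
    names = ["DenseVNet/conv_1/w", "DenseVNet/conv_2/w", "DenseVNet/final_conv/w"]
    # Negative lookahead: restore everything except the final classification layer.
    restore, init = split_variables(names, r"^(?!.*final_conv).*")
    print(restore)
    print(init)
```

In a TF1 graph, the restore list could then be turned into variables and passed as the `var_list` argument of `tf.train.Saver`, with the remaining variables initialised through `tf.variables_initializer`.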

Alternatively, if you want full control of the graph, the code example below could be a good starting point. A caveat is that the preprocessing layers (not included here) should exactly match the ones used for training, e.g. https://github.com/NifTK/NiftyNet/blob/v0.3.0/niftynet/application/segmentation_application.py#L178.

import tensorflow as tf
import os
from niftynet.io.image_reader import ImageReader
from niftynet.engine.sampler_resize import ResizeSampler
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"]='1'


##### Address of the model to be restored
check_point_location='/home/niftynet/models/dense_vnet_abdominal_ct/models/model.ckpt-3000'
#####

##### Create a sampler

data_param = {'image': {'path_to_search': '~/niftynet/data/dense_vnet_abdominal_ct',
                        'filename_contains': 'CT', 'spatial_window_size': (144, 144, 144)}}

reader = ImageReader().initialise(data_param)

sampler = ResizeSampler(
    reader=reader,
    data_param=data_param,
    batch_size=1,
    shuffle_buffer=True,
    queue_length=35)

#####

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:

    sampler.run_threads(sess, tf.train.Coordinator(), num_threads=1)

    from niftynet.network.dense_vnet import DenseVNet
    data_dict = sampler.pop_batch_op()
    net_logits = DenseVNet(num_classes=9)(data_dict['image'])

    # restore the variables
    saver = tf.train.Saver()
    saver.restore(sess, check_point_location)

    net_logits = sess.run(net_logits)
    print(net_logits.shape)
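On the preprocessing caveat: the network only sees sensible inputs if the same normalisation is applied at inference as during training. A minimal NumPy sketch of zero-mean, unit-variance normalisation (whitening), assuming that is what the training configuration used; check the training config for the actual normalisation layers and their order before relying on this:

```python
import numpy as np


def whiten(image: np.ndarray, mask: np.ndarray = None, eps: float = 1e-8) -> np.ndarray:
    """Zero-mean, unit-variance normalisation of a single image volume.

    Mean and standard deviation are computed over the voxels selected by the
    boolean `mask` (or over the whole volume when no mask is given) and then
    applied to every voxel. `eps` guards against division by zero.
    """
    image = image.astype(np.float32)
    voxels = image[mask] if mask is not None else image
    mean = voxels.mean()
    std = voxels.std()
    return (image - mean) / max(std, eps)
```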

wyli, Oct 22 '18

Thank you! I would like to work on it; any suggestions?

I think I can try this, thanks!

johnnychhsu, Oct 23 '18

@wyli Hi, I tried to modify the attached code, but ran into some problems.

import tensorflow as tf
import os
import pdb
from niftynet.io.image_reader import ImageReader
# only sampler_resize_v2 exists in the dev branch
from niftynet.engine.sampler_resize_v2 import ResizeSampler
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"]='3'


##### Address of the model to be restored
check_point_location='/home/niftynet/models/dense_vnet_abdominal_ct/models/model.ckpt-3000'
#####

##### Create a sampler
data_param = {'image': {'path_to_search': '~/niftynet/data/dense_vnet_abdominal_ct',
                        'filename_contains': 'CT', 'spatial_window_size': (144, 144, 144)}}


reader = ImageReader().initialise(data_param)

sampler = ResizeSampler(
    reader=reader,
    window_sizes=(144, 144, 144),
    batch_size=1,
    shuffle=True,
    queue_length=35)

#####

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
    sampler.run_threads(sess, tf.train.Coordinator(), num_threads=1)
    from niftynet.network.dense_vnet import DenseVNet
    data_dict = sampler.layer_op()
    # pdb.set_trace()
    input_data = tf.placeholder(tf.float32, shape=(None, 144, 144, 144, 1, 1))
    # I think the input should be defined as a placeholder in the computation graph first,
    # and the actual data then fed into the model
    net_logits = DenseVNet(num_classes=9)(input_data)
    pdb.set_trace()

    # restore the variables
    saver = tf.train.Saver()
    saver.restore(sess, check_point_location)

    net_logits = sess.run(net_logits, feed_dict={input_data: data_dict['image']})
    print(net_logits.shape)

The sampler in the dev branch is sampler_resize_v2. The input is defined as a placeholder tensor, and the data is fed into the graph when running the session.

However, I still got an error:

Traceback (most recent call last):
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/framework/tensor_util.py", line 518, in make_tensor_proto
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/framework/tensor_util.py", line 518, in <listcomp>
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/util/compat.py", line 67, in as_bytes
    (bytes_or_text,))
TypeError: Expected binary or unicode string, got None

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test_transfer.py", line 36, in <module>
    net_logits = DenseVNet(num_classes=9)(input_data)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/niftynet/layer/base_layer.py", line 34, in __call__
    return self._op(*args, **kwargs)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/ops/template.py", line 351, in __call__
    result = self._call_func(args, kwargs)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/ops/template.py", line 302, in _call_func
    result = self._func(*args, **kwargs)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/niftynet/network/dense_vnet.py", line 195, in layer_op
    input_tensor = augment_layer(input_tensor)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/niftynet/layer/base_layer.py", line 34, in __call__
    return self._op(*args, **kwargs)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/ops/template.py", line 351, in __call__
    result = self._call_func(args, kwargs)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/ops/template.py", line 302, in _call_func
    result = self._func(*args, **kwargs)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/niftynet/layer/affine_augmentation.py", line 79, in layer_op
    batch_size, spatial_rank)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/niftynet/layer/affine_augmentation.py", line 56, in _random_transform
    output_corners = tf.tile([output_corners], [batch_size, 1, 1])
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/ops/gen_array_ops.py", line 5587, in tile
    "Tile", input=input, multiples=multiples, name=name)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 513, in _apply_op_helper
    raise err
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 510, in _apply_op_helper
    preferred_dtype=default_dtype)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1036, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 235, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 214, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/framework/tensor_util.py", line 522, in make_tensor_proto
    "supported type." % (type(values), values))
TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [None, 1, 1]. Consider casting elements to a supported type.

originally defined at:
  File "/home/dsa321123321/test/lib/python3.5/site-packages/niftynet/network/dense_vnet.py", line 194, in layer_op
    hyper['augmentation_scale'], 'LINEAR', 'ZERO')
  File "/home/dsa321123321/test/lib/python3.5/site-packages/niftynet/layer/affine_augmentation.py", line 35, in __init__
    Layer.__init__(self, name=name)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/niftynet/layer/base_layer.py", line 26, in __init__
    self._op = tf.make_template(name, self.layer_op, create_scope_now_=True)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/ops/template.py", line 152, in make_template
    **kwargs)


originally defined at:
  File "/home/dsa321123321/test/lib/python3.5/site-packages/niftynet/network/dense_vnet.py", line 93, in __init__
    name=name)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/niftynet/network/base_net.py", line 21, in __init__
    super(BaseNet, self).__init__(name=name)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/niftynet/layer/base_layer.py", line 58, in __init__
    super(TrainableLayer, self).__init__(name=name)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/niftynet/layer/base_layer.py", line 26, in __init__
    self._op = tf.make_template(name, self.layer_op, create_scope_now_=True)
  File "/home/dsa321123321/test/lib/python3.5/site-packages/tensorflow/python/ops/template.py", line 152, in make_template
    **kwargs)

Do you have any idea? Thank you!
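The final TypeError points at the likely cause: `tf.tile` received the multiples `[None, 1, 1]`, i.e. the batch size used by DenseVNet's affine augmentation layer was `None`, which is what a placeholder declared with a `None` first dimension provides at graph-construction time. A pure-Python sketch of that failure mode (the helper name is hypothetical; it only mirrors the `tf.tile([output_corners], [batch_size, 1, 1])` call shown in the traceback):

```python
from typing import List, Optional, Sequence


def tile_multiples(static_shape: Sequence[Optional[int]]) -> List[int]:
    """Build tf.tile-style multiples [batch, 1, 1] from a static input shape.

    A None batch dimension cannot be turned into a constant list of
    multiples, which is the condition reported by the TypeError above.
    """
    batch_size = static_shape[0]
    if batch_size is None:
        raise TypeError("batch dimension is unknown at graph-construction time; "
                        "declare the input with a fixed batch size")
    return [batch_size, 1, 1]
```

If this reading is right, declaring the placeholder with a fixed batch size instead of `None` should avoid this particular error.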

johnnychhsu, Oct 26 '18

Hi @johnnychhsu, the problem in your code is the layer_op() call. The following works for me:

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
    from niftynet.network.dense_vnet import DenseVNet
    data_dict = sampler.pop_batch_op()
    net_logits = DenseVNet(num_classes=9)(data_dict['image'])
    saver = tf.train.Saver()
    saver.restore(sess, check_point_location)

    net_logits = sess.run(net_logits)
    print(net_logits.shape)

If you want to use placeholders, one possibility would be to override the sampler's dataset initialisation (https://github.com/NifTK/NiftyNet/blob/dev/niftynet/engine/image_window_dataset.py#L207) with something like tf.data.Dataset.from_tensor_slices(placeholders).
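`tf.data.Dataset.from_tensor_slices` splits its inputs along the first axis, one element per slice, and a subsequent `.batch(n)` regroups consecutive slices (keeping a smaller final batch by default). A NumPy-only sketch of that slicing-and-batching behaviour, useful for checking that the arrays you plan to feed through placeholders have the layout the network expects; it illustrates the semantics only and is neither NiftyNet nor TensorFlow code:

```python
from typing import Dict, Iterator

import numpy as np


def slice_and_batch(arrays: Dict[str, np.ndarray],
                    batch_size: int) -> Iterator[Dict[str, np.ndarray]]:
    """Yield dicts of consecutive first-axis slices, like from_tensor_slices(...).batch(n).

    All arrays must share the same length along axis 0; the last batch may be smaller.
    """
    lengths = {len(a) for a in arrays.values()}
    if len(lengths) != 1:
        raise ValueError("all arrays must have the same first dimension")
    n = lengths.pop()
    for start in range(0, n, batch_size):
        yield {key: a[start:start + batch_size] for key, a in arrays.items()}
```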

wyli, Oct 26 '18

@wyli Thank you! If I want to work on this (transfer learning), do you have any suggestions on where to start?

johnnychhsu, Oct 26 '18

Yes @johnnychhsu, there are a few tasks in this direction:

  1. the transfer learning interface has been improved by @aleks-djuric in https://github.com/NifTK/NiftyNet/pull/258; it would be great to have a step-by-step tutorial for transfer learning in https://github.com/NifTK/NiftyNet/tree/dev/demos (Python notebooks + example configuration files).

  2. make a wrapper around NiftyNet's IO module, so that we can fine-tune a TensorFlow Hub module with training samples generated using NiftyNet; pseudocode:

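import tensorflow as tf
import tensorflow_hub as hub

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
    # TensorFlow Hub (this URL is the text-embedding module from the question,
    # standing in for a segmentation module)
    module_url = "https://tfhub.dev/google/nnlm-en-dim128-with-normalization/1"
    segmentation_net = hub.Module(module_url)
    my_image_placeholder = tf.placeholder(...)
    output = segmentation_net(my_image_placeholder)
    ...

    # loading data with niftynet's IO
    data_dict = sampler.pop_batch_op()
    # TODO: make data_dict compatible with segmentation_net
    # feed_dict values must be concrete arrays, so evaluate the sampler output first
    image_batch = sess.run(data_dict['image'])
    sess.run(finetuning_model_op, feed_dict={my_image_placeholder: image_batch})
    ...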
  3. improve NiftyNet's application module so that we can initialise a TensorFlow Hub module within NiftyNet applications, i.e., extend the network initialisation function https://github.com/NifTK/NiftyNet/blob/v0.4.0/niftynet/application/segmentation_application.py#L257. Ideally it would take a TF Hub module_url as input and return an initialised network instance.
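For the third task, the network initialisation step could dispatch on its input: a known network name builds a built-in network, while a URL is handed to a module loader. A minimal sketch under that assumption; `make_network_factory`, the registry contents, and the injected `hub_loader` are all hypothetical and not part of NiftyNet:

```python
from typing import Callable, Dict


def make_network_factory(registry: Dict[str, Callable[..., object]],
                         hub_loader: Callable[[str], object]) -> Callable[..., object]:
    """Return build(spec, **kwargs) that creates a network from a name or a URL.

    URLs (http:// or https://) are passed to `hub_loader`; any other spec is
    lower-cased, looked up in `registry` (keys expected in lower case) and
    called with **kwargs.
    """
    def build(spec: str, **kwargs) -> object:
        if spec.startswith(("http://", "https://")):
            return hub_loader(spec)
        key = spec.lower()
        if key not in registry:
            raise ValueError("unknown network: %s" % spec)
        return registry[key](**kwargs)

    return build
```

With TF1 and `tensorflow_hub` installed, `hub_loader` could simply be `hub.Module`, since `hub.Module(url)` builds the module's graph; injecting it keeps the factory testable without network access.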

Combining these features, we would be able to fine-tune a TensorFlow Hub model in NiftyNet without writing Python code :)

wyli, Oct 26 '18

Hi @wyli, I'd like to work on these step by step; I think I can start with the first one! Thank you!

johnnychhsu, Oct 28 '18

@johnnychhsu I want to use a NiftyNet pre-trained segmentation model to segment custom data. I downloaded the pre-trained weights and modified the model_dir path to point to them. However, when I run

python3 net_segment.py train -c /home/Container_data/config/promise12_demo_train_config.ini

I get the error below.

Caused by op 'save/Assign_17', defined at:
  File "net_segment.py", line 8, in <module>
    sys.exit(main())
  File "/home/NiftyNet/niftynet/__init__.py", line 142, in main
    app_driver.run(app_driver.app)
  File "/home/NiftyNet/niftynet/engine/application_driver.py", line 197, in run
    SESS_STARTED.send(application, iter_msg=None)
  File "/usr/local/lib/python3.5/dist-packages/blinker/base.py", line 267, in send
    for receiver in self.receivers_for(sender)]
  File "/usr/local/lib/python3.5/dist-packages/blinker/base.py", line 267, in <listcomp>
    for receiver in self.receivers_for(sender)]
  File "/home/NiftyNet/niftynet/engine/handler_model.py", line 109, in restore_model
    var_list=to_restore, save_relative_paths=True)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1102, in __init__
    self.build()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1114, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1151, in _build
    build_save=build_save, build_restore=build_restore)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 795, in _build_internal
    restore_sequentially, reshape)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 428, in _AddRestoreOps
    assign_ops.append(saveable.restore(saveable_tensors, shapes))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 119, in restore
    self.op.get_shape().is_fully_defined())
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/state_ops.py", line 221, in assign
    validate_shape=validate_shape)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 61, in assign
    use_locking=use_locking, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3274, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1770, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Assign requires shapes of both tensors to match. lhs shape= [3,3,61,256] rhs shape= [3,3,3,61,9]
	 [[node save/Assign_17 (defined at /home/NiftyNet/niftynet/engine/handler_model.py:109) = Assign[T=DT_FLOAT, _class=["loc:@DenseVNet/conv/conv_/w"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](DenseVNet/conv/conv_/w, save/RestoreV2/_35)

AkhilaPerumalla123, Feb 13 '19

Your custom data doesn't seem to be the same dimension as the data the network was trained on.

Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error: Assign requires shapes of both tensors to match. lhs shape= [3,3,61,256] rhs shape= [3,3,3,61,9]

Without knowing more about what you're trying to do, we can't help you. However, I would suggest following these steps before requesting additional help.
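The quoted shapes can be read directly: the checkpoint kernel `[3,3,3,61,9]` looks like a 3-D convolution with 9 output channels, while the graph built from the config expects `[3,3,61,256]`, a 2-D kernel with 256 output channels, so the spatial rank and number of classes in the config appear to differ from those the checkpoint was trained with. A small sketch for listing every such mismatch at once, assuming both sides are available as `{variable_name: shape}` dicts (the function name is hypothetical):

```python
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def shape_mismatches(graph_shapes: Dict[str, Sequence[int]],
                     ckpt_shapes: Dict[str, Sequence[int]]) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, graph_shape, checkpoint_shape) for shared variables whose shapes differ."""
    mismatches = []
    for name in sorted(graph_shapes.keys() & ckpt_shapes.keys()):
        graph_shape, ckpt_shape = tuple(graph_shapes[name]), tuple(ckpt_shapes[name])
        if graph_shape != ckpt_shape:
            mismatches.append((name, graph_shape, ckpt_shape))
    return mismatches
```

In TF1, the checkpoint side can be collected with `tf.train.list_variables(checkpoint_path)`, which returns `(name, shape)` pairs, and the graph side with `{v.op.name: v.shape.as_list() for v in tf.global_variables()}`.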

alekswithakayy, Feb 13 '19

@johnnychhsu were you able to use placeholders with the NiftyNet pipeline? I'm unable to restore the model when using a placeholder; please let me know if you fixed this issue.

Thank you

koriavinash1, Jun 17 '19

@wyli is it possible to train a network in a similar way, using tf Sessions directly?

jillianlee, Jun 25 '20