caffe-tensorflow

Fine-tune the model in TensorFlow

Open huyangc opened this issue 8 years ago • 7 comments

I am looking for a way to change the last layer of the model so that I can fine-tune the network. For example, ImageNet classification has 1000 classes, but I would like to fine-tune AlexNet for a 100-class classification task. How can I change the model generated by caffe-tensorflow? Is there any way to do this without changing the source code of caffe-tensorflow?

huyangc, Nov 22 '16 05:11
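On the "without changing the source code" part of the question, one possible workaround (not proposed in this thread) is to strip the last layer out of the converted `.npy` weights file before loading it. This is a minimal sketch assuming the file holds a pickled dict of the form `{layer_name: {param_name: ndarray}}`, which is the structure `Network.load` iterates over, and that AlexNet's final layer is named `fc8` as in Caffe's reference model; `drop_layers` is a hypothetical helper, not part of caffe-tensorflow.

```python
import numpy as np


def drop_layers(src_path, dst_path, layers_to_drop):
    """Copy a caffe-tensorflow weights file, leaving out the given layers.

    Assumes the .npy file holds a pickled dict {layer_name: {param_name: ndarray}}.
    Returns the sorted names of the layers that were kept.
    """
    # allow_pickle=True: newer NumPy will not unpickle the dict without it.
    # encoding='latin1': lets Python 3 read files pickled under Python 2.
    data_dict = np.load(src_path, allow_pickle=True, encoding='latin1').item()
    kept = {name: params for name, params in data_dict.items()
            if name not in layers_to_drop}
    # np.save wraps the dict in a 0-d object array, the same format load() reads.
    np.save(dst_path, kept)
    return sorted(kept)


# Hypothetical usage: remove AlexNet's last layer (assumed to be named 'fc8').
# drop_layers('alexnet.npy', 'alexnet_no_fc8.npy', {'fc8'})
```

Because `load` only visits layers that are present in the file, a network whose `fc8` has been resized to 100 outputs in the generated model definition would keep that layer's fresh initialization when loaded from the filtered file, provided the variables are initialized before `load` is called.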

+1

pribadihcr, Nov 24 '16 06:11

+1

Patrickcxt, Nov 26 '16 16:11

+1

tportenier, Feb 06 '17 11:02

Perhaps you can refer to my repo; it is a modified version of caffe-tensorflow.

joelthchao, Feb 06 '17 17:02

@joelthchao Thanks for the hint. Inspired by your code, I modified `kaffe.tensorflow.Network.load` to take a list of layer names that should be ignored when restoring weights, either because they are missing from the modified model or because one wants to train those weights from scratch. This indeed solves the problem for me!

tportenier, Feb 14 '17 12:02
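The selective restore described above can be illustrated without TensorFlow. In this sketch the TensorFlow variables are replaced by a hypothetical `model_params` dict of expected shapes, and `plan_restore` is an illustrative name, not part of caffe-tensorflow. It also adds an explicit shape check (not in the thread's code) to show why a resized last layer has to be skipped: the pretrained `fc8` weights have 1000 output columns, while a 100-class head expects 100. Layer names and sizes in the example are made up.

```python
import numpy as np


def plan_restore(data_dict, model_params, scratch_layers=(), ignore_missing=False):
    """Decide which pretrained arrays to copy into a (modified) model.

    data_dict:    {layer: {param: ndarray}} as stored in the .npy file
    model_params: {layer: {param: shape}} describing the modified model
                  (a stand-in for the TensorFlow variables)
    Returns {(layer, param): ndarray} for every array that should be assigned.
    """
    to_assign = {}
    for layer, params in data_dict.items():
        if layer in scratch_layers:
            continue  # keep the model's fresh initialization for this layer
        for param, data in params.items():
            shape = model_params.get(layer, {}).get(param)
            if shape is None:
                if ignore_missing:
                    continue  # layer/param no longer exists in the model
                raise ValueError("missing in model: %s/%s" % (layer, param))
            if tuple(shape) != data.shape:
                raise ValueError("shape mismatch for %s/%s: %s vs %s"
                                 % (layer, param, data.shape, tuple(shape)))
            to_assign[(layer, param)] = data
    return to_assign
```

With a 100-class head, calling `plan_restore(pretrained, model)` raises on `fc8`, while passing `scratch_layers=['fc8']` restores only the earlier layers and leaves `fc8` to be trained from scratch.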

@tportenier Would it be possible to share the modification? I'm interested in doing something similar.

SamComber, Jun 12 '18 22:06

Here you go! This is what I used (a long time ago):

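```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope / tf.get_variable)


# Replacement for kaffe.tensorflow.Network.load (defined inside the Network class).
def load(self, data_path, session, ignore_missing=False, scratch_layers=None):
    '''Load network weights.
    data_path: The path to the numpy-serialized network weights
    session: The current TensorFlow session
    ignore_missing: If true, serialized weights for missing layers are ignored.
    scratch_layers: List of strings (layer names to be ignored during load)
    '''
    # Avoid a mutable default argument.
    scratch_layers = scratch_layers or []
    # Newer NumPy versions refuse to unpickle the stored dict
    # unless allow_pickle=True is passed.
    data_dict = np.load(data_path, allow_pickle=True).item()
    for op_name in data_dict:
        if op_name in scratch_layers:
            continue  # leave this layer to be trained from scratch
        with tf.variable_scope(op_name, reuse=True):
            # .items() works on Python 2 and 3 (.iteritems() is Python 2 only).
            for param_name, data in data_dict[op_name].items():
                try:
                    var = tf.get_variable(param_name)
                    session.run(var.assign(data))
                except ValueError:
                    if not ignore_missing:
                        raise
```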

tportenier, Jun 20 '18 14:06