Age-Gender-Estimate-TF
Age-Gender-Estimate-TF copied to clipboard
模型固化问题
能在train.py训练模型的同时,保存为ckpt格式之外,再保存为一份.pb格式的么? 我在网上找了一段保存为pb的代码 from tensorflow.python.framework import graph_util
##################模型封装 output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=["output"]) # 形参output_node_names用于指定输出的节点名称(不是指的要输出的节点的名称) pb_file_path = " my_net/save_net.pb" with tf.gfile.FastGFile(pb_file_path, mode='wb') as f: f.write(output_graph_def.SerializeToString())
##################模型封装 但是因为是初学者没有找到output_node_names(用于指定输出的节点名称)在你的代码中具体变量名是哪个,能指点一下么?谢谢。
如果我没有弄错的话,你可以在这行之后加入 gender_out = tf.argmax(tf.nn.softmax(gender_logits), 1,name='gender_out ') age_out = tf.reduce_sum(tf.multiply(tf.nn.softmax(age_logits), age_), axis=1,name='age_out ') 函数中的两个name就是你需要的output_node_names
或者你也可以看一下被注释掉的这段, 如果没有记错的话当时也是为了生成pb文件写的。
多谢,我这边因为机器配置低的原因,现训练模型所需时间太长了,想直接把ckpt的转化成pb的,因为我直接跑你的demo.py(摄像头调用的)成功了,对应的模型是savedmodel.ckpt,所以想把savedmodel.ckpt转换成pb的,在网上找了一下资料,有如下代码,
import tensorflow as tf import os.path import argparse from tensorflow.python.framework import graph_util
MODEL_DIR = "model/pb" MODEL_NAME = "frozen_model.pb"
if not tf.gfile.Exists(MODEL_DIR): #创建目录 tf.gfile.MakeDirs(MODEL_DIR)
def freeze_graph(model_folder): checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用 input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径 output_graph = os.path.join(MODEL_DIR, MODEL_NAME) #PB模型保存路径
output_node_names = "gender_out" #原模型输出操作节点的名字
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) #得到图、clear_devices :Whether or not to clear the device field for an `Operation` or `Tensor` during import.
#print( saver.name)
graph = tf.get_default_graph() #获得默认的图
input_graph_def = graph.as_graph_def() #返回一个序列化的图代表当前的图
with tf.Session() as sess:
saver.restore(sess, input_checkpoint) #恢复图并得到数据
#print "predictions : ", sess.run("predictions:0", feed_dict={"input_holder:0": [10.0]}) # 测试读出来的模型是否正确,注意这里传入的是输出 和输入 节点的 tensor的名字,不是操作节点的名字
output_graph_def = graph_util.convert_variables_to_constants( #模型持久化,将变量值固定
sess,
input_graph_def,
output_node_names.split(",") #如果有多个输出节点,以逗号隔开
)
with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
for op in graph.get_operations():
print(op.name, op.values())
if name == 'main': #parser = argparse.ArgumentParser() #parser.add_argument("model_folder", type=str, help="input ckpt model dir") #命令行解析,help是提示符,type是输入的类型, # 这里运行程序时需要带上模型ckpt的路径,不然会报 error: too few arguments #aggs = parser.parse_args() #freeze_graph(aggs.model_folder) freeze_graph("models2") #模型目录 这里也是需要原模型输出操作节点的名字,请问你还记得savedmodel对应性别和年龄判断对应节点的名称么
我当时训练的时候没有手动指定名称,所以名称应该是tf自动分配的,我可以晚些时候回去帮你看一下,可能要明天才能回复你。或者你直接尝试下我刚才的第二条回复应该也是可行的。
没关系的,麻烦你了,你说的第二条方法成功生成pb文件了,多谢 我测试一下看看效果再
因为我想把模型迁移到安卓上面,安卓端需要下面这三个变量名称,可以在train.py中添加进模型么? 模型中输入变量的名称(最开始的入口节点,是demo.py中第95行的input_image么?) 模型中输出变量的名称(最后的输出节点,没有找到) 输入值的大小(我看了一下demo.py中的95行,是[160, 160, 3]这个大小么?长宽三颜色通道数) 刚开始接触不到3天,望理解小白,非常感谢
好的,多谢指点,这两个名称能在train.py训练的时候就固化加入到pb模型中么?
事实上,这两个参数并不是网络的一部分。因此不需要也不可能加入到模型中。这里是网络最后的输出,后面的年龄和性别的判断可以在测试的时候直接构建。
嗯,好的,我再看看怎样才能迁移到安卓端吧,多谢了。
@rochou 老哥方便给个邮箱或者其他联系方式么,迁移到安卓端的方法我也有一些疑惑,能不能交流一下?
@shaonfu 你留个qq号吧 我加你
@rochou 362483671
@rochou @shaonfu 你们迁移到安卓端了吗
@315386775,我没有迁移,你们呢?
@rochou 请问上面你说的转成pb文件,采用第二个回复注释掉方法生成的pb是只包含节点的pb吧,因为我看大小为4M,如果想生成frozen后的pb,还是绕不过需要output_name。 而且生成只包含节点的pb后,我想用tensorflow里面summarize_graph打印输入输出节点信息,但是显示一堆错误,还是得不到name。 另外我在train的时候,加了保存graph的一句话,使用output_name为save/save_all(忘了使用哪种方法print出来的一个name),确实生成了一个很大的pb文件,但是使用summarize_graph打印这个pb文件节点信息的时候,没报错,但是input_name和output_name全部为空。 上面第二个回复提到的生成pb文件的过程,我看别的人说的,貌似是一种生成模型和使用模型解耦的意思,其实还是没有output_name,我尝试了age、ages、age_logits等name试了,都说graph中不存在这个name。 或者可以重新训练的话,这个output_name可以在哪里指定吗,在inception_resnet_v1网络里面,还是在train里面呢?按照上面第一条回复,加入 gender_out = tf.argmax(tf.nn.softmax(gender_logits), 1,name='gender_out ') age_out = tf.reduce_sum(tf.multiply(tf.nn.softmax(age_logits), age_), axis=1,name='age_out ') 还是遇到了错误,加入这两句话后,gender_out和age_out是加入到Graph中了吗?
@bityangbing 这个我也不清楚,最后我采用的是谷歌官方迁移学习demo来实现区分男女的
@bityangbing 我用tensorboard 看graph 可以找到节点并且试着输入output_node_name 可以生成一个大约90M的pb文件 但是接下去这个pb文件不知道可不可以直接迁移安卓里面去...
@shaonfu 我刚接触tf几天,还没用过tensorboard,请问,你说的"试着输入output_node_name"意思是给他增加了一个output_name?还是用tensorboard看到了他最后节点name,并当做了output_name?你还记得output_name是什么吗,他应该是两个输出吧,是两个name组成的list吗?
@bityangbing 看到了他最后的节点 有个叫AddN的节点你可以试试当做output_node_name 生成一个pb文件
模型的输入输出节点可以发出来吗
在 eval.py中加入打印就可以得到输出节点名了:
print(age_logits, age_logits.get_shape())
#Tensor("logits/age/BiasAdd:0", shape=(?, 101), dtype=float32) (?, 101)
print(gender_logits, gender_logits.get_shape())
#Tensor("logits/gender/BiasAdd:0", shape=(?, 2), dtype=float32) (?, 2)
@bityangbing 这个我也不清楚,最后我采用的是谷歌官方迁移学习demo来实现区分男女的
Demo在哪里?给个Link,谢谢