TensorFlow模型保存与加载方法全解析

一、TensorFlow常规模型加载方法


保存模型

tf.train.Saver()类,.save(sess, ckpt文件目录)方法


参数名称功能说明默认值
var_listSaver中存储变量集合全局变量集合
reshape加载时是否恢复变量形状True
sharded是否将变量轮循放在所有设备上True
max_to_keep保留最近检查点个数5
restore_sequentially是否按顺序恢复变量,模型较大时顺序恢复内存消耗小True

var_list是字典形式{变量名字符串: 变量符号},相对应的restore也根据同样形式的字典将ckpt中的字符串对应的变量加载给程序中的符号。

如果Saver给定了字典作为加载方式,则按照字典来,​​如:saver ​​​​=​​​ ​​tf.train.Saver({​​​​"v/ExponentialMovingAverage"​​​​:v}),否则每个变量寻找自己的name属性在ckpt中的对应值进行加载。​


加载模型

当我们基于checkpoint文件(ckpt)加载参数时,实际上我们使用Saver.restore取代了initializer的初始化

TensorFlow模型保存和载入方法汇总_tensorflow

checkpoint文件会记录保存信息,通过它可以定位最新保存的模型:

登录后复制

​​ckpt ​​​​=​​​ ​​tf.train.get_checkpoint_state(​​​​'./model/'​​​​)​​​​print​​​​(ckpt.model_checkpoint_path)​​1.2.

TensorFlow模型保存和载入方法汇总_加载_02

.meta文件保存了当前图结构.data文件保存了当前参数名和值.index文件保存了辅助索引信息

.data文件可以查询到参数名和参数值,使用下面的命令可以查询保存在文件中的全部变量{名:值}对,

登录后复制

​​from​​​ ​​tensorflow.python.tools.inspect_checkpoint ​​​​import​​​ ​​print_tensors_in_checkpoint_file​​​​print_tensors_in_checkpoint_file(os.path.join(savedir,savefile),​​​​None​​​​,​​​​True​​​​)​​1.2.

tf.train.import_meta_graph函数给出model.ckpt-n.meta的路径后会加载图结构,并返回saver对象

登录后复制

​​ckpt ​​​​=​​​ ​​tf.train.get_checkpoint_state(​​​​'./model/'​​​​)​​1.

tf.train.Saver函数会返回加载默认图的saver对象,saver对象初始化时可以指定变量映射方式,根据名字映射变量

登录后复制

​​saver ​​​​=​​​ ​​tf.train.Saver({​​​​"v/ExponentialMovingAverage"​​​​:v})  ​​1.

saver.restore函数给出model.ckpt-n的路径后会自动寻找参数名-值文件进行加载

登录后复制

​​saver.restore(sess,​​​​'./model/model.ckpt-0'​​​​)​​​​saver.restore(sess,ckpt.model_checkpoint_path)​​1.2.


1.不加载图结构,只加载参数

由于实际上我们参数保存的都是Variable变量的值,所以其他的参数值(例如batch_size)等,我们在restore时可能希望修改,但是图结构在train时一般就已经确定了,所以我们可以使用tf.Graph().as_default()新建一个默认图(建议使用上下文环境),利用这个新图修改和变量无关的参值大小,从而达到目的。

登录后复制

​​'''​​​​使用原网络保存的模型加载到自己重新定义的图上​​​​可以使用python变量名加载模型,也可以使用节点名​​​​'''​​​​import​​​ ​​AlexNet as Net​​​​import​​​ ​​AlexNet_train as train​​​​import​​​ ​​random​​​​import​​​ ​​tensorflow as tf​​ ​​IMAGE_PATH ​​​​=​​​ ​​'./flower_photos/daisy/5673728_71b8cb57eb.jpg'​​ ​​with tf.Graph().as_default() as g:​​ ​​x ​​​​=​​​ ​​tf.placeholder(tf.float32, [​​​​1​​​​, train.INPUT_SIZE[​​​​0​​​​], train.INPUT_SIZE[​​​​1​​​​], ​​​​3​​​​])​​​​y ​​​​=​​​ ​​Net.inference_1(x, N_CLASS​​​​=​​​​5​​​​, train​​​​=​​​​False​​​​)​​ ​​with tf.Session() as sess:​​​​# 程序前面得有 Variable 供 save or restore 才不报错​​​​# 否则会提示没有可保存的变量​​​​saver ​​​​=​​​ ​​tf.train.Saver()​​ ​​ckpt ​​​​=​​​ ​​tf.train.get_checkpoint_state(​​​​'./model/'​​​​)​​​​img_raw ​​​​=​​​ ​​tf.gfile.FastGFile(IMAGE_PATH, ​​​​'rb'​​​​).read()​​​​img ​​​​=​​​ ​​sess.run(tf.expand_dims(tf.image.resize_images(​​​​tf.image.decode_jpeg(img_raw),[​​​​224​​​​,​​​​224​​​​],method​​​​=​​​​random.randint(​​​​0​​​​,​​​​3​​​​)),​​​​0​​​​))​​ ​​if​​​ ​​ckpt ​​​​and​​​ ​​ckpt.model_checkpoint_path:​​​​print​​​​(ckpt.model_checkpoint_path)​​​​saver.restore(sess,​​​​'./model/model.ckpt-0'​​​​)​​​​global_step ​​​​=​​​ ​​ckpt.model_checkpoint_path.split(​​​​'/'​​​​)[​​​​-​​​​1​​​​].split(​​​​'-'​​​​)[​​​​-​​​​1​​​​]​​​​res ​​​​=​​​ ​​sess.run(y, feed_dict​​​​=​​​​{x: img})​​​​print​​​​(global_step,sess.run(tf.argmax(res,​​​​1​​​​)))​​1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.26.27.28.29.30.31.32.


 2.加载图结构和参数

登录后复制

​​'''​​​​直接使用使用保存好的图​​​​无需加载python定义的结构,直接使用节点名称加载模型​​​​由于节点形状已经定下来了,所以有不便之处,placeholder定义batch后单张传会报错​​​​现阶段不推荐使用,以后如果理解深入了可能会找到使用方法​​​​'''​​​​import​​​ ​​AlexNet_train as train​​​​import​​​ ​​random​​​​import​​​ ​​tensorflow as tf​​ ​​IMAGE_PATH ​​​​=​​​ ​​'./flower_photos/daisy/5673728_71b8cb57eb.jpg'​​  ​​ckpt ​​​​=​​​ ​​tf.train.get_checkpoint_state(​​​​'./model/'​​​​)                          ​​​​# 通过检查点文件锁定最新的模型​​​​saver ​​​​=​​​ ​​tf.train.import_meta_graph(ckpt.model_checkpoint_path ​​​​+​​​​'.meta'​​​​)   ​​​​# 载入图结构,保存在.meta文件中​​ ​​with tf.Session() as sess:​​​​saver.restore(sess,ckpt.model_checkpoint_path)                        ​​​​# 载入参数,参数保存在两个文件中,不过restore会自己寻找​​ ​​img_raw ​​​​=​​​ ​​tf.gfile.FastGFile(IMAGE_PATH, ​​​​'rb'​​​​).read()​​​​img ​​​​=​​​ ​​sess.run(tf.image.resize_images(​​​​tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method​​​​=​​​​random.randint(​​​​0​​​​, ​​​​3​​​​)))​​​​imgs ​​​​=​​​ ​​[]​​​​for​​​ ​​i ​​​​in​​​ ​​range​​​​(​​​​128​​​​):​​​​imgs.append(img)​​​​print​​​​(sess.run(tf.get_default_graph().get_tensor_by_name(​​​​'fc3:0'​​​​),feed_dict​​​​=​​​​{​​​​'Placeholder:0'​​​​: imgs}))​​ ​​'''​​​​img ​​​​=​​​ ​​sess.run(tf.expand_dims(tf.image.resize_images(​​​​tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method​​​​=​​​​random.randint(​​​​0​​​​, ​​​​3​​​​)), ​​​​0​​​​))​​​​print​​​​(img)​​​​imgs ​​​​=​​​ ​​[]​​​​for​​​ ​​i ​​​​in​​​ ​​range​​​​(​​​​128​​​​):​​​​imgs.append(img)​​​​print​​​​(sess.run(tf.get_default_graph().get_tensor_by_name(​​​​'conv1:0'​​​​),​​​​feed_dict​​​​=​​​​{​​​​'Placeholder:0'​​​​:img}))​​1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.26.27.28.29.30.31.32.33.34.35.36.

注意,在所有两种方式中都可以通过调用节点名称使用节点输出张量,节点.name属性返回节点名称。


 3.简化版本

登录后复制

​​# 连同图结构一同加载​​​​ckpt ​​​​=​​​ ​​tf.train.get_checkpoint_state(​​​​'./model/'​​​​)​​​​saver ​​​​=​​​ ​​tf.train.import_meta_graph(ckpt.model_checkpoint_path ​​​​+​​​​'.meta'​​​​)​​​​with tf.Session() as sess:​​​​saver.restore(sess,ckpt.model_checkpoint_path)​​ ​​# 只加载数据,不加载图结构,可以在新图中改变batch_size等的值​​​​# 不过需要注意,Saver对象实例化之前需要定义好新的图结构,否则会报错​​​​saver ​​​​=​​​ ​​tf.train.Saver()​​​​with tf.Session() as sess:​​​​ckpt ​​​​=​​​ ​​tf.train.get_checkpoint_state(​​​​'./model/'​​​​)​​​​saver.restore(sess,ckpt.model_checkpoint_path)​​1.2.3.4.5.6.7.8.9.10.11.12.


二、TensorFlow二进制模型加载方法

这种加载方法一般是对应网上各大公司已经训练好的网络模型进行修改的工作

登录后复制

​​# 新建空白图​​​​self​​​​.graph ​​​​=​​​ ​​tf.Graph()​​​​# 空白图列为默认图​​​​with ​​​​self​​​​.graph.as_default():​​​​# 二进制读取模型文件​​​​with tf.gfile.FastGFile(os.path.join(model_dir,model_name),​​​​'rb'​​​​) as f:​​​​# 新建GraphDef文件,用于临时载入模型中的图 ​​​​graph_def ​​​​=​​​ ​​tf.GraphDef()​​​​# GraphDef加载模型中的图​​​​graph_def.ParseFromString(f.read())​​​​# 在空白图中加载GraphDef中的图​​​​tf.import_graph_def(graph_def,name​​​​=​​​​'')​​​​# 在图中获取张量需要使用graph.get_tensor_by_name加张量名​​​​# 这里的张量可以直接用于session的run方法求值了​​​​# 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量​​​​self​​​​.input_tensor ​​​​=​​​ ​​self​​​​.graph.get_tensor_by_name(​​​​self​​​​.input_tensor_name)​​​​self​​​​.layer_tensors ​​​​=​​​ ​​[​​​​self​​​​.graph.get_tensor_by_name(name ​​​​+​​​ ​​':0'​​​​) ​​​​for​​​ ​​name   ​​​​in​​​ ​​self​​​​.layer_operation_names]​​1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.

上面两篇都使用了二进制加载模型的方式


三、二进制模型制作

这节是关于tensorflow的Freezing,字面意思是冷冻,可理解为整合合并;整合什么呢,就是将模型文件和权重文件整合合并为一个文件,主要用途是便于发布。

tensorflow在训练过程中,通常不会将权重数据保存的格式文件里(这里我理解是模型文件),反而是分开保存在一个叫checkpoint的检查点文件里,当初始化时,再通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型和权重数据分开保存的情况,使得发布产品时不是那么方便,我们可以将tf的图和参数文件整合进一个后缀为pb的二进制文件中,由于整合过程回将变量转化为常量,所以我们在日后读取模型文件时不能够进行训练,仅能向前传播,而且我们在保存时需要指定节点名称。

将图变量转换为常量的API:​ ​tf.graph_util.convert_variables_to_constants​

转换后的graph_def对象转换为二进制数据(graph_def.SerializeToString())后,写入pb即可。

登录后复制

​​import​​​ ​​tensorflow as tf​​ ​​v1 ​​​​=​​​ ​​tf.Variable(tf.constant(​​​​1.0​​​​, shape​​​​=​​​​[​​​​1​​​​]), name​​​​=​​​​'v1'​​​​)​​​​v2 ​​​​=​​​ ​​tf.Variable(tf.constant(​​​​2.0​​​​, shape​​​​=​​​​[​​​​1​​​​]), name​​​​=​​​​'v2'​​​​)​​​​result ​​​​=​​​ ​​v1 ​​​​+​​​ ​​v2​​ ​​saver ​​​​=​​​ ​​tf.train.Saver()​​​​with tf.Session() as sess:​​​​sess.run(tf.global_variables_initializer())​​​​saver.save(sess, ​​​​'./tmodel/test_model.ckpt'​​​​)​​​​gd ​​​​=​​​ ​​tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), [​​​​'add'​​​​])​​​​with tf.gfile.GFile(​​​​'./tmodel/model.pb'​​​​, ​​​​'wb'​​​​) as f:​​​​f.write(gd.SerializeToString())​​1.2.3.4.5.6.7.8.9.10.11.12.13.


我们可以直接查看gd:

登录后复制

node {   name: "v1"   op: "Const"   attr {     key: "dtype"     value {       type: DT_FLOAT     }   }   attr {     key: "value"     value {       tensor {         dtype: DT_FLOAT         tensor_shape {           dim {             size: 1           }         }         float_val: 1.0       }     }   } } …… node {   name: "add"   op: "Add"   input: "v1/read"   input: "v2/read"   attr {     key: "T"     value {       type: DT_FLOAT     }   } } library {1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.26.27.28.29.30.31.32.33.34.35.36.37.38.


四、从图上读取张量

上面的代码实际上已经包含了本小节的内容,但是由于从图上读取特定的张量是如此的重要,所以我仍然单独的补充上这部分的内容。

无论如何,想要获取特定的张量我们必须要有张量的名称图的句柄,比如 'import/pool_3/_reshape:0' 这种,有了张量名和图,索引就很简单了。



从二进制模型加载张量

第二小节的代码很好的展示了这种情况


免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删

QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空