TensorFlow模型持久化存储技巧

模型持久化的目的在于可以使模型训练后的结果重复使用,节省重复训练模型的时间。模型保存

train.Saver类是TensorFlow提供的用于保存和还原模型的API,使用非常简单。

登录后复制

import tensorflow as tf


# 声明两个变量并计算其加和
a = tf.Variable(tf.constant([1.0, 2.0], shape=[2]), name='a')
b = tf.Variable(tf.constant([3.0, 4.0], shape=[2]), name='b')
result = a + b

# 初始化全部变量的操作
init_op = tf.global_variables_initializer()
# 定义 Saver 类对象用于保存模型
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "./model/model.ckpt")

上面的代码实现了一个简单的TensorFlow模型持久化的功能。

save()函数的sess参数用于指定要保存的模型会话,save_path参数用于指定路径。

通过Saver类的save()函数将TensorFlow模型保存到一个指定路径下的model.ckpt文件中。

(TensorFlow模型一般会保存在文件名为.ckpt的文件中,可以省略后缀名,但是好的编程习惯是对其加以指定)

虽然上面的程序只制定了一个文件路径,但是在这个文件目录下回出现4个文件:
TensorFlow模型持久化_tensorflow

  • checkpoint文件是一个文本文件,保存了一个目录下所有的模型文件列表。该文件会被自动更新,当有更多模型被保存到model目录下时,文件内容会更新为最新的训练模型。
  • model.ckpt.data-00000-of-00001文件是一个二进制文件,保存了TensorFlow中每一个变量的取值。
  • model.ckpt.index文件是一个二进制文件,保存了每一个变量的名称,是一个string-string的table,其中table的key值为tensor名,value值为BundleEntryProto。
  • model.ckpt.meta文件是一个二进制文件,保存了计算图的结构。

将一个模型文件分成多个文件保存的原因是TensorFlow会将模型的计算图结构以及参数的取值分开来保存。模型加载

TensorFlow也提供了相应的函数来加载保存的模型。

登录后复制

with tf.Session() as sess:
    saver.restore(sess, "./model/model.ckpt")
    print(sess.run(result))

输出:
TensorFlow模型持久化_加载_02
加载模型的代码和保存模型的代码相似,但是省略了初始化全部变量的过程。

使用restore()函数需要在模型参数恢复前定义计算图上的所有运算,并且变量名需要与模型存在的变量名保持一致,这样就可以将变量的值通过已保存的模型加载进来。

有时我们可能不希望重复定义计算图上的计算,太繁琐了,TensorFlow提供了import_meta_graph()函数加载模型的计算图。

import_meta_graph()函数的输入参数为.meta文件的路径,返回一个Saver类实例,再调用这个实例的restore()函数就可以恢复参数了。

登录后复制

saver = tf.train.import_meta_graph("./model/model.ckpt.meta")

with tf.Session() as sess:
    saver.restore(sess, "./model/model.ckpt")
    # 获取默认计算图上指定节点处的张量
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

输出:
TensorFlow模型持久化_加载_03
.ckpt.meta文件保存了计算图的结构,通过import_meta_graph()函数将计算图导入到程序中并传递给saver,之后在会话中通过restore()函数对该计算图中变量的值进行加载。

get_tensor_by_name()函数用于获取指定节点处的张量(add:0 表示add节点的第一个输出)。



免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删

QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空