TensorFlow模型保存与加载的实用技巧

我们在训练模型时,要将 tensorflow 训练好的模型参数保存起来,以便以后进行验证或测试。此时就涉及到将模型持久化,和模型的再次读取。tensorflow中提供了这些类。



模型保存 tf.train.Saver()

  • 模型保存,先要创建一个Saver对象。
    在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0。
    当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可。(其实我们只需要保存最后一个就好)
  • write_meta_graph:bool型,是否把TensorFlow的图保存下来
  • 创建完saver对象后,就可以保存训练好的模型了。
  • 登录后复制
saver=tf.train.Saver(max_to_keep=1)saver.save(sess,'./name.ckpt',global_step=step)# 参数: 
保存的sess,路径和名字,训练的次数作为后缀加入到模型名字中# saver.save(sess, 'model.ckpt', global_step=1000) ==> 
filename: 'model-1000'1.2.3.4.

TensorFlow之六—模型保存与加载_git

  • meta:保存了TensorFlow的graph。包括all variables,operations,collections等等。
  • checkpoint:二进制文件,保存了所有weights,biases,gradient and all the other variables的值。也就是上图中的.data-00000-of-00001和.index文件。.data文件包含了所有的训练变量。与此同时,Tensorflow还有一个名为checkpoint的文件,只保存最新检查点文件的记录,即最新的保存路径。

模型恢复 restore()

模型恢复​​restore(sess, save_path)​​函数:

 ​​sess​​​:模型

​​​save_path​​:模型保存的模型路径。使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型


登录后复制

model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)1.2.




登录后复制
sess = tf.Session()#加载整个graph:不需要重新构建整个graph了new_saver = tf.train.import_meta_graph('../model/model_LR_test.meta')#加载模型中各种变量的值,注意这里不用文件的后缀new_saver.restore(sess,'../model/model_LR_test') #对应第一个文件的add_to_collection()函数#返回值是一个list,我们要的是第一个,这也说明可以有多个变量的名字一样。hyp = tf.get_collection('hypothesis')[0]graph = tf.get_default_graph() X = graph.get_operation_by_name('X').outputs[0]#为了将placeholder加载出来pred = sess.run(hyp,feed_dict = {X:x_valid})print('auc:',auc(y_valid,pred))1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.

小案例

登录后复制

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=False)x = tf.placeholder(tf.float32, [None, 784])y_=tf.placeholder(tf.int32,[None,])dense1 = tf.layers.dense(inputs=x,                          units=1024,                          activation=tf.nn.relu,                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),                         kernel_regularizer=tf.nn.l2_loss)dense2= tf.layers.dense(inputs=dense1,                          units=512,                          activation=tf.nn.relu,                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),                         kernel_regularizer=tf.nn.l2_loss)logits= tf.layers.dense(inputs=dense2,                         units=10,                         activation=None,                        kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),                        kernel_regularizer=tf.nn.l2_loss)loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)    acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))sess=tf.InteractiveSession()  sess.run(tf.global_variables_initializer())is_train=Truesaver=tf.train.Saver(max_to_keep=3)#训练阶段if is_train:    max_acc=0    f=open('ckpt/acc.txt','w')         # 1.新建txt文件记录训练acc与loss过程    for i in range(100):      batch_xs, batch_ys = mnist.train.next_batch(100)      sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})      val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})      print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))      f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')  # 2.写入文本      if val_acc>max_acc:             # 确保保存最好的精度          max_acc=val_acc          saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)    f.close()                         # 3.关闭文本#验证阶段else:    model_file=tf.train.latest_checkpoint('ckpt/')    saver.restore(sess,model_file)   # 模型恢复    val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})    print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))sess.close()1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.26.27.28.29.30.31.32.33.34.35.36.37.38.39.40.41.42.43.44.45.46.47.48.49.50.51.52.53.54.55.56.57.

虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内)。

         

免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删

QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空