如何将PyTorch模型转换为TensorFlow模型

加载之前训练好的模型,继续训练防止程序意外退出拿不到训练结果。

(一)Tensorflow模型介绍

通常我们训练好中后都会得到这样几个文件
tensorflow如何加载pytorch模型 tensorflow加载训练好的模型_加载


1).meta 文件

是一个协议缓冲区,可以保存完整的 Tensorflow图  即所有 变量,操作,集合  等。此文件具有.meta扩展名

2).data 文件

.data 文件是一个二进制文件,包括 权重,偏差,渐变和所有其他保存变量的所有值  。.data-00000of00001只是后缀,加载的时候不用写,只写model.ckpt即可。详情见后面

其中,-27150表示第27150次训练得到的结果

(二)模型的保存(保存所有的所有参数的图形和值)

1、 首先要建立一个saver对象:如

登录后复制

saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5)

max_to_keep 表示保存模型的个数,max_to_keep=5表示保存最新的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

登录后复制

saver = tf.train.Saver( max_to_keep=0)

但是这样不推荐,一般不这样做,浪费存储空间


2、 创建完saver对象后,使用saver.save()就可以保存训练好的模型了,如:

登录后复制

checkpoint_path='/Users/model/ade20k/model.ckpt'
#checkpoint_path  模型存储地址 后缀为.ckpt
saver.save(self.sess, checkpoint_path,global_step=step)
print('The checkpoint has been created, step: {}'.format(step))#可以打印出来

self.sess是创建的会话,因为所有的变量仅在会话中存在。因此,您必须通过在刚创建的saver对象上调用save方法将模型保存在会话session中。第二个参数设置保存的路径和名字,第三个参数global_step将训练的次数作为后缀加入到模型名字中。
完整的保存程序应该是


登录后复制

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()#创建对象saver
sess = tf.Session()#创建会话
sess.run(tf.global_variables_initializer())#初始化会话所有变量
saver.save(sess, 'my_test_model',global_step=1000)#将第1000次训练的模型保存到会话中

可根据需求,个性化保存训练文件  。

1)让我们说,在训练时,我们在每1000次迭代后保存我们的模型,所以.meta文件是第一次创建(第1000次迭代),我们不需要每次都重新创建.meta文件(所以,我们不要t保存.meta文件在2000,3000 …或任何其他迭代)。我们只保存模型以进行进一步的迭代,因为图形不会改变。因此,当我们不想编写元图时,我们使用这个:

登录后复制

saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)


2)如果您只想保留4个最新型号并希望在培训期间每2小时保存一个型号,则可以使用max_to_keep和keep_checkpoint_every_n_hours。

登录后复制

#saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)


3)最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代

登录后复制

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100)
    batch_xs,batch_ys=mnist.train.next_batch(100)
    sess.run(train_op,feed_dict={x: batch_xs,y_: batch_ys})
    val_loss,val_acc=sess.run([loss,acc],feed_dict={x:mnist.test.images,y_:mnist.test.labels})
    print('epoch:%d, val_loss:%f,val_acc:%f'%(i,val_loss,val_acc))
    if val_acc>max_acc:
        max_acc=val_ac        
        saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()


4)如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

登录后复制

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
  if val_acc>max_acc:
      max_acc=val_acc
      saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()

(三)、模型的加载与恢复

如果想使用其他人的预训练模型进行微调,需要做两件事:


1)创建网络结构  。
可以通过编写python代码来创建网络,以手动创建每个图层作为原始模型。但是,我们已将网络保存在.meta文件中,我们可以使用tf.train.import()函数重新创建网络,如下所示

登录后复制

new_saver = tf.train.import_meta_graph("/Users/yanni_Z/python/model.ckpt-100.meta")

import_meta_graph将.meta文件中定义的网络加载到当前图形。创建图形/网络,接着还需要加载我们在此图表上训练过的参数值。


2)加载参数

用restore()函数,它需要两个参数restore(sess, save_path)tf.train.latest_checkpoint()来自动获取最后一次保存的模型

登录后复制

with tf.Session() as sess:
 new_saver =tf.train.import_meta_graph('my_test_model-1000.meta')
 new_saver.restore(sess,tf.train.latest_checkpoint('./'))
 

或者加载特定的训练模型

登录后复制

with tf.Session() as sess:
 tf.train.import_meta_graph("/Users/yanni_Z/python/model.ckpt-100.meta")#加载网络结构
 loader.restore(sess,"/Users/yanni_z/python/model/odel.ckpt-100")#载入权重等参数
 

(四)使用已经恢复的模型

如果想使用已经载入的模型进行 预测、微调,进一步训练。 使用Tensorflow时,先定义一个图表,其中包含示例(训练数据)和一些超参数,如学习率,全局步骤等。使用占位符提供所有训练数据和超参数的标准做法。让我们使用占位符构建一个小型网络并保存它。 或者加载特定的训练模型

登录后复制

import tensorflow as tf
#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}
#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#Create a saver object which will save all the variables
saver = tf.train.Saver()
#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 
#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

当我们想要恢复它时,我们不仅要恢复图形和权重,还要准备一个新的feed_dict,将新的训练数据提供给网络。我们可以通过graph.get_tensor_by_name()方法引用这些保存的操作和占位符变量。


登录后复制

#How to access saved variable/Tensor/placeholders 
w1 = graph.get_tensor_by_name("w1:0")
## How to access saved operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

如果我们只想使用不同的数据运行相同的网络,您只需通过feed_dict将新数据传递到网络即可。


登录后复制

import tensorflow as tf
sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 
#using new values of w1 and w2 and saved value of b1.

如果想在原来的网路图中添加更多操作,该怎么办?当然你也可以这样做。看这里:


登录后复制

import tensorflow as tf
sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)
print sess.run(add_on_op,feed_dict)
#This will print 120.

可以恢复旧图形的一部分并对其进行附加以进行微调?当然,您可以通过graph.get_tensor_by_name()方法访问相应的操作,并在其上构建图形。在这里,我们使用元图加载一个vgg预训练网络,并在最后一层将输出数量更改为2,以便使用新数据进行微调。


登录后复制

saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 
#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')
#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()
new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)
# Now, you run this with fine-tuning data in sess.run()


免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删

QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空