TensorFlow模型保存技巧:如何保存模型参数

一、保存、读取说明

我们创建好模型之后需要保存模型,以方便后续对模型的读取与调用,保存模型我们可能有下面三种需求:1、只保存模型权重参数;2、同时保存模型图结构与权重参数;3、在训练过程的检查点保存模型数据。下面分别对这三种需求进行实现。


二、仅保存模型参数

仅保存模型参数可以用一下的API:

Model.save_weights(file_path) # 将文件保存到save_path
Model.load_weights(file_path) # 将文件读取到save_path

注意:由于save_weights只是保存权重w、b的参数值,所以在加载时最好保证我们的模型结构和原来保存的模型结构是相同的,否则可能会报错。.

模型在保存之后会有多个文件:

  • index类型文件,在分布式计算中,索引文件会指示哪些权重存储在哪个分片。
  • checkpoint类型文件,检查文件点包含: 一个或多个包含模型权重的分片
  • 如果您只在一台机器上训练模型,那么您将有一个带有后缀的分片:.data-00000-of-00001



import tensorflow as tf
import os

# 读取数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

# 数据集归一化
train_images = train_images / 255
train_labels = train_labels / 255  # 进行数据的归一化,加快计算的进程

# 创建模型结构
net_input=tf.keras.Input(shape=(28,28))
fl=tf.keras.layers.Flatten()(net_input)#调用input
l1=tf.keras.layers.Dense(32,activation="relu")(fl)
l2=tf.keras.layers.Dropout(0.5)(l1)
net_output=tf.keras.layers.Dense(10,activation="softmax")(l2)

# 创建模型类
model = tf.keras.Model(inputs=net_input, outputs=net_output)

# 查看模型的结构
model.summary()

# 模型编译
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss="sparse_categorical_crossentropy",
              metrics=['acc'])

# 模型训练
model.fit(train_images, train_labels, batch_size=50, epochs=5, validation_split=0.1)

# 模型存放路径
save_path = './save_weights/'
model.save_weights(save_path)

# 模型加载
model.load_weights(save_path)

# # 定义一个与原模型结构不同的模型
# net_in=tf.keras.Input(shape=(748,))
# net_out=tf.keras.layers.Dense(10,activation="softmax")(net_in)
# 
# # 用不同结构的模型读取参数,这里会报错
# model2=tf.keras.Model(inputs=net_in,outputs=net_out)
# model2.load_weights(save_path)

三、同时保存结构与参数

Keras使用HDF5标准提供基本保存格式,出于我们的目的,可以将保存的模型视为单个二进制blob。

保存完整的模型非常有用,使我们可以在TensorFlow.js(HDF5, Saved Model) 中加载它们,然后在Web浏览器中训练和运行它们,或者使用TensorFlow Lite(HDF5, Saved Model)将它们转换为在移动设备上运行。

# 模型训练
model.fit(train_images, train_labels, batch_size=50, epochs=5, validation_split=0.1)

# 保存模型
model.save('net_model.h5')

# 模型加载
new_model=tf.keras.models.load_model('net_model.h5')

四、在训练过程的检查点保存模型数据

在训练过程的检查点保存模型数据有两个作用:1、我们可以保存训练各个节点的数据,便于我们把训练效果最好的节点的模型挑选出来。2、可以随时先暂停训练模型,当想要训练时继续训练。

在训练的检查点保存模型需要用到tf.keras.callbacks.ModelCheckpoint()类,这个是一个回调类,可以以列表形式传入到fit()方法的callbacks参数中。

回调中类,文件名以.ckpt作为后缀,如文件路径'./checkpoint/train.ckpt',会在checkpoint生成三个文件,后缀与Model.save_weights()方法创建的文件后缀相同,意义也相同。以下为回调类的参数:

tf.keras.callbacks.ModelCheckpoint()

  • filepath:string,保存模型文件的路径。
  • monitor:监控:要监控的数量。
  • verbose详细:详细模式,0或1。
  • save_best_only:如果save_best_only = True,则不会覆盖根据监控数量的最新最佳模型。
  • save_weights_only:如果为True,则只有模型的权重
    保存(model.save_weights(filepath)),否则保存完整模型(model.save(filepath))。
  • mode:{auto,min,max}之一。 如果save_best_only =
    True,则根据监控数量的最大化或最小化来决定覆盖当前保存文件。
    对于val_acc,这应该是max,对于val_loss,这应该是min等。在自动模式下,从监控量的名称自动推断方向。
  • period:检查点之间的间隔(时期数)。

  

# 模型编译
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss="sparse_categorical_crossentropy",
              metrics=['acc'])

# 创建一个保存模型的回调函数,每5个周期保存一次权重
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='./checkpoint/train.ckpt',
    verbose=1,
    save_weights_only=True,
    period=5
)

# 模型训练
model.fit(train_images, train_labels, batch_size=50, epochs=5, validation_split=0.1, callbacks=[cp_callback])

# 加载模型
model.load_weights('./checkpoint/train.ckpt')

# # 继续训练模型
# model.fit()

五、模型可训练参数的提取

有时候我们需要查看模型的参数,但是模型参数的显示有时候由于数据过多不能再控制台全部显示,所以需要存放到文件来查看。以下是提取可训练参数的方法:

# 参看可训练参数
import numpy as np
model.trainable_variables

# 设置全部可训练参数可打印,不然数据过多,有一部分会以省略号的形式显示
np.set_printoptions(threshold=np.inf)

# 可训练参数保存到文件
with open('trainable.txt', mode='w',encoding = "utf-8") as f:
    for t_v in model.trainable_variables:
        f.writelines(str(t_v.name) + '\n')  # 保存参数名字
        f.writelines(str(t_v.shape) + '\n')  # 保存参数形状
        f.writelines(str(t_v.numpy()) + '\n')  # 保存参数数值
        

   

免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删

相关推荐
技术文档
软件下载
QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空