TensorFlow 2.0模型保存方法

介绍

​​模型保存有5种:1、整体保存;2、网络架构保存;3、权重保存;4、回调保存;5、自定义训练模型的保存​​

​​1、整体保存:权重值,模型配置(架构),优化器配置​​

​整个模型可以保存到一个文件中,其中包含权重值、模型配置乃至优化器配置。这样,您就可以为模型设置检查点,并稍后从完全相同的状态继续训练,而无需访问原始代码。​

​在Keras中保存完全可以正常使用的额模型非常有用,您可以在TensorFlow.js中加载它们,然后在网络浏览器中训练和运行它们。​

​Keras使用HDF5标准提供基本的保存格式。​



登录后复制

# -*- coding: UTF-8 -*-"""Author: LGDFileName: save_modelDateTime: 2020/12/23 16:21 
SoftWare: PyCharm"""# the whole model saveimport tensorflow as tfimport osimport 
pandasimport numpy as npimport matplotlib.pyplot as plt(train_image, train_label), 
(test_image, test_label) = tf
.keras.datasets
.fashion_mnist.load_data()train_image = train_image / 255test_image = test_image / 255
# model = tf.keras
.Sequential()# model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
# model.add(tf.keras.layers
.Dense(128, activation='relu'))# model.add(tf.keras.layers.Dense(10, activation='softmax'))
## model
.summary()## model.compile(#     optimizer='adam',#     loss='sparse_categorical_crossentropy',
#     metrics=['acc']# )## model.fit(train_image, train_label, epochs=5)# save model# model
.save('less_model.h5')# use the modelnew_model = tf
.keras.models.load_model('less_model.h5')new_model
.summary()# evaluate the modeleva_result = new_model
.evaluate(test_image, test_label, verbose=0)  
# 0表示不显示提示print('evaluate result: ', eva_result)1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.
16.17.18.19.20.21.22.23.24.25.26.27.



2、仅保存架构

有时我们只对模型的架构感兴趣,而无需保存权重值或优化器。在这种情况下,可以仅保存模型的“配置”。

登录后复制

json_config = model.to_json()reinitialized_model = tf.keras.models
.model_from_json(json_config)reinitialized_model.summary()# 若要使用就需要配置reinitialized_model
.compile(    optimizer='adam',    loss='sparse_categorical_crossentropy',    
metrics=['acc'])1.2.3.4.5.6.7.8.9.



3、仅保存权重

有时我们只需要保存模型的状态(其权重值),而对模型架构不感兴趣。在这种情况下,可以通过get_weights()获取权重值,并通过set_weights()设置权重值。

登录后复制

weights = model.get_weights()print(weights)# 设置权重reinitialized_model.set_weights(weights)
# 保存权重model.save_weights('less_weights.h5')reinitialized_model.load_weights('less_weights.h5')
# 保存架构和保存权重合在一起并不能和保存整个模型等同,还有优化器配置没有保存。1.2.3.4.5.6.7.8.9.



4、在训练期间保存检查点

在训练期间或训练结束时自动保存检查点,这样一来,你便可以使用经过训练的模型,而无需重新训练该模型,或从上次暂停的地方继续训练,以防训练过程中断。

回调函数:tf.keras.callbacks.ModelCheckpoint

登录后复制

# # 4、在训练期间保存检查点# checkpoint_path = 'cp.ckpt'
# cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True)#
# model = tf.keras.Sequential()# model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
# model.add(tf.keras.layers.Dense(128, activation='relu'))# model.add(tf.keras.layers
.Dense(10, activation='softmax'))## model.summary()## model.compile(#     optimizer='adam',
#     loss='sparse_categorical_crossentropy',#     metrics=['acc']# )#
# eva_result = model.evaluate(test_image, test_label, verbose=0)  
# 0表示不显示提示# print(eva_result)# # model.fit(train_image, train_label, epochs=5, 
callbacks=[cp_callback])## model.load_weights(checkpoint_path)# eva_result = model
.evaluate(test_image, test_label, verbose=0)  # 0表示不显示提示# print(eva_result)1.



5、在自定训练中保存模型

登录后复制

# 在自定义训练中保存检查点model = tf.keras.Sequential()model.add(tf.keras.layers
.Flatten(input_shape=(28, 28)))model.add(tf.keras.layers.Dense(128, activation='relu'))model
.add(tf.keras.layers.Dense(10, activation='softmax'))optimizer = tf.keras
.optimizers.Adam()loss_func = tf.keras.losses
.SparseCategoricalCrossentropy(from_logits=True)def
loss(model, x, y):    y_ = model(x)    return loss_func(y, y_)train_loss = tf
.keras.metrics.Mean('train_loss', dtype=tf.float32)train_accuracy = tf
.keras.metrics.SparseCategoricalAccuracy('train_accuracy')test_loss = tf
.keras.metrics.Mean('test_loss', dtype=tf.float32)test_accuracy = tf
.keras.metrics.SparseCategoricalAccuracy('test_accuracy')def 
train_step(model, images, labels):    
with tf.GradientTape() as t:        pred = model(images)        
loss_step = loss_func(labels, pred)    
grads = t.gradient(loss_step, model.trainable_variables)    
optimizer.apply_gradients(zip(grads, model
.trainable_variables))    train_loss(loss_step)    
train_accuracy(labels, pred)dataset = tf
.data.Dataset.from_tensor_slices((train_image, train_label))
dataset = dataset.shuffle(10000)
.batch(32)cp_dir = './custom_train_cp'cp_prefix = os
.path.join(cp_dir, 'ckpt')checkpoint = tf
.train.Checkpoint(    optimizer=optimizer,    model=model)# def train():#     
for epoch in range(5):
#         for (batch, (images, labels)) in enumerate(dataset):#             
train_step(model, images, labels)#         print('Epoch{} loss is {}'
.format(epoch, train_loss.result()))
#         print('Epoch{} accuracy is {}'.format(epoch, train_accuracy.result()))## 
train_accuracy.reset_states()#         train_loss.reset_states()##         
checkpoint.save(file_prefix=cp_prefix)### # 训练# train()
# 恢复模型print(checkpoint.restore(tf.train.latest_checkpoint(cp_dir)))
# 测试恢复的模型print(tf.argmax(model(train_image, training=False), axis=-1)
.numpy())print(train_label)print((tf.argmax(model(train_image, training=False), axis=-1)
.numpy() ==      train_label).sum()/len(train_label))1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18
.19.20.21.22.23.24.25.26.27.28.29.30.31.32.33.34.35.36.37.38.39.40.41.42.43.44.45.46.47.48.49.



免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删

QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空