TensorFlow - 模型持久化与状态恢复
在机器学习工作流中,模型的保存与加载是关键环节。通过持久化机制,可避免重复训练,提升开发效率,并支持断点续训。TensorFlow 提供了多种方式实现模型状态的存储与还原。
1. 训练过程中的检查点管理
训练期间自动创建检查点,将权重以二进制格式保存至文件集合。这些文件仅包含参数值,不包含模型结构或优化器状态。使用 ModelCheckpoint 回调可实现自动化保存:
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
checkpoint_path = "checkpoints/model-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True,
verbose=1,
period=5 # 每5个训练周期保存一次
)
通过将此回调传入 model.fit(),系统将在指定周期后生成对应命名的权重文件,便于后续恢复。
2. 手动保存权重
若需在特定时刻保存模型参数,可直接调用 save_weights 接口:
model = create_model()
model.fit(x_train, y_train, epochs=10)
model.save_weights('./saved_weights/my_model')
该方法适用于仅需保留训练结果、无需重建完整模型的场景。
3. 完整模型的序列化
为实现端到端的可复现性,可将整个模型(包括架构、权重、优化器配置)保存为单个文件。推荐使用 HDF5 格式:
model.save('full_model.h5') # 保存完整模型
loaded_model = tf.keras.models.load_model('full_model.h5') # 加载并恢复所有状态
注意:当前版本中,基于 tf.train 的优化器无法被完整保存,需在加载后重新编译模型。
4. 数据准备:MNIST 手写数字识别
本例使用经典的 MNIST 数据集,包含 60,000 张训练图像和 10,000 张测试图像,每张为 28×28 像素的灰度图。数据预处理包括归一化与形状重塑:
train_images = train_images.reshape(-1, 784) / 255.0
test_images = test_images.reshape(-1, 784) / 255.0
5. 模型定义与训练流程
构建一个简单的全连接网络用于分类任务:
def build_network():
model = tf.keras.Sequential([
tf.keras.layers.Dense(512, activation='relu', input_shape=(784,)),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
return model
训练完成后,可通过不同方式加载状态并验证性能,实现从"未训练"到"已恢复"的准确率跃升。
6. 文件结构说明
- 检查点文件:由
.index和.data-*组成,如cp-0010.ckpt.index与cp-0010.ckpt.data-00000-of-00001 - HDF5 模型文件:
.h5格式,封装全部信息,体积较大但使用便捷 - 手动保存权重:生成独立的权重文件,适合轻量级部署
7. 注意事项与警告处理
出现以下警告属正常现象:
"This model was compiled with a Keras optimizer ... but is being saved in TensorFlow format with
save_weights."
原因:使用的是 Keras 优化器而非 TensorFlow 原生优化器,导致其状态无法随权重一同保存。解决方案是:
- 若需保留优化器状态,应改用
tf.keras.optimizers中的兼容版本; - 或接受状态丢失,仅恢复权重。
该警告不影响模型功能,可忽略。