当前位置:首页 > 技术 > 正文内容

TensorFlow - 模型持久化与状态恢复

访客 技术 2026年6月28日 1

在机器学习工作流中,模型的保存与加载是关键环节。通过持久化机制,可避免重复训练,提升开发效率,并支持断点续训。TensorFlow 提供了多种方式实现模型状态的存储与还原。

1. 训练过程中的检查点管理

训练期间自动创建检查点,将权重以二进制格式保存至文件集合。这些文件仅包含参数值,不包含模型结构或优化器状态。使用 ModelCheckpoint 回调可实现自动化保存:

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint

checkpoint_path = "checkpoints/model-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=1,
    period=5  # 每5个训练周期保存一次
)

通过将此回调传入 model.fit(),系统将在指定周期后生成对应命名的权重文件,便于后续恢复。

2. 手动保存权重

若需在特定时刻保存模型参数,可直接调用 save_weights 接口:

model = create_model()
model.fit(x_train, y_train, epochs=10)
model.save_weights('./saved_weights/my_model')

该方法适用于仅需保留训练结果、无需重建完整模型的场景。

3. 完整模型的序列化

为实现端到端的可复现性,可将整个模型(包括架构、权重、优化器配置)保存为单个文件。推荐使用 HDF5 格式:

model.save('full_model.h5')  # 保存完整模型
loaded_model = tf.keras.models.load_model('full_model.h5')  # 加载并恢复所有状态

注意:当前版本中,基于 tf.train 的优化器无法被完整保存,需在加载后重新编译模型。

4. 数据准备:MNIST 手写数字识别

本例使用经典的 MNIST 数据集,包含 60,000 张训练图像和 10,000 张测试图像,每张为 28×28 像素的灰度图。数据预处理包括归一化与形状重塑:

train_images = train_images.reshape(-1, 784) / 255.0
test_images = test_images.reshape(-1, 784) / 255.0

5. 模型定义与训练流程

构建一个简单的全连接网络用于分类任务:

def build_network():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    return model

训练完成后,可通过不同方式加载状态并验证性能,实现从"未训练"到"已恢复"的准确率跃升。

6. 文件结构说明

  • 检查点文件:由 .index.data-* 组成,如 cp-0010.ckpt.indexcp-0010.ckpt.data-00000-of-00001
  • HDF5 模型文件.h5 格式,封装全部信息,体积较大但使用便捷
  • 手动保存权重:生成独立的权重文件,适合轻量级部署

7. 注意事项与警告处理

出现以下警告属正常现象:

"This model was compiled with a Keras optimizer ... but is being saved in TensorFlow format with save_weights."

原因:使用的是 Keras 优化器而非 TensorFlow 原生优化器,导致其状态无法随权重一同保存。解决方案是:

  • 若需保留优化器状态,应改用 tf.keras.optimizers 中的兼容版本;
  • 或接受状态丢失,仅恢复权重。

该警告不影响模型功能,可忽略。

相关文章

Linux crontab 详解

1) crontab 是什么cron 是 Linux 的定时任务守护进程;crontab 是用来编辑/查看“按时间周期执行命令”的表(cron table)。常见两类:用户 crontab:每个用户一份(crontab -e 编辑)系统级 crontab / cron.d:可指定执行用户(/etc/crontab、/etc/cron.d/*)2) crontab 时间...

富文本里可以允许的 HTML 属性

一、所有标签默认允许的安全属性(极少)class        (可选)id           (通常建议禁用)title️ 注意:id 容易被滥用做锚点注入,很多系统直接禁用class 允许的话最好只允许固定前缀(如 editor-*)二、a 标签允许属性<a href="" t...

Mac 安装 Node.js 指南

方法一:通过官网安装包(最简单,适合初学者)如果你只是想快速安装并开始使用,这是最直接的方法。访问 Node.js 官网。页面会显示两个版本:LTS (Recommended For Most Users):长期支持版,最稳定。建议选这个。Current:最新特性版,包含最新功能但可能不够稳定。下载 .pkg 安装包并运行。按照安装向导点击“下一步”即可完成。方法二:使用 Homebrew 安装(...

Dom\HTML_NO_DEFAULT_NS 的副作用:自动加闭合标签

在使用Dom\HTMLDocument时,Dom\HTML_NO_DEFAULT_NS 将禁止在解析过程中设置元素的命名空间, 此设置是为了与DOMDocument向后兼容而存在的。当使用它时,已知的一个副作用就是:自动加闭合标签例如 </img> 为什么会这样?当你使用:Dom\HTML_NO_DEFAULT_NS文档会变成 无命名空间模式,此时内部更接近 XML...

Laravel 事件和监听器创建

在 Laravel 中,使用 Artisan 命令创建 Events(事件) 和 Listeners(监听器) 是非常高效的。你可以通过以下几种方式来实现:1. 手动创建单个 Event如果你只想创建一个事件类,可以使用 make:event 命令:Bashphp artisan make:event UserRegistered执行后,文件将生成在 app/Even...

自定义域名解析神器 dnsmasq

什么是 dnsmasq?dnsmasq 是一个轻量级、功能强大的网络服务工具,专为小型和中等规模网络设计。它是一个综合的网络基础设施解决方案[1]。dnsmasq 能做什么?功能说明应用场景DNS 转发与缓存将 DNS 查询转发到上游服务器(ISP、Google DNS 等),并在本地缓存结果加快 DNS 查询速度,减少外部 DNS 流量本地 DNS解析本地网络设备的主机名,无需编辑&n...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。