当前位置:首页 > 工具 > 正文内容

TensorFlow实战:构建情感分析模型处理IMDB影评

访客 工具 2026年6月9日 1

项目概览

本文演示如何使用TensorFlow和Keras构建一个文本分类模型,对IMDB电影评论进行正面/负面情感分类。完整代码可在GitHub仓库获取。

数据准备与探索

加载数据集

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import pathlib
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
print(f"TensorFlow: {tf.__version__}, Keras: {tf.keras.__version__}")

data_dir = str(pathlib.Path.cwd()) + "/datasets/imdb/"

# 检查预下载的numpy文件
np_data = np.load(data_dir + "imdb.npz")
print("Keys in npz:", list(np_data.keys()))

# 加载IMDB数据集,仅保留前10000个高频词
imdb = keras.datasets.imdb
(train_texts, train_labels), (test_texts, test_labels) = imdb.load_data(
    path=data_dir + "imdb.npz",
    num_words=10000
)

数据结构分析

每条评论已预处理为整数序列,每个整数对应词典中的一个单词。标签0表示负面,1表示正面。

print(f"训练样本数: {len(train_texts)}, 标签数: {len(train_labels)}")
print(f"第一条评论长度: {len(train_texts[0])}, 第二条: {len(train_texts[1])}")
print(f"第一条评论整数序列:\n{train_texts[0]}")

整数到文本的逆向转换

word_index = imdb.get_word_index(data_dir + "imdb_word_index.json")
word_index = {k: (v + 3) for k, v in word_index.items()}
word_index["<PAD>"] = 0
word_index["<START>"] = 1
word_index["<UNK>"] = 2
word_index["<UNUSED>"] = 3

reverse_lookup = dict([(value, key) for (key, value) in word_index.items()])

def decode_review(encoded_text):
    """将整数序列还原为可读文本"""
    return ' '.join([reverse_lookup.get(i, '?') for i in encoded_text])

print("原始文本:", decode_review(train_texts[0]))

数据预处理

所有评论序列需要统一长度,通过填充操作实现。

max_len = 256
train_data = keras.preprocessing.sequence.pad_sequences(
    train_texts,
    value=word_index["<PAD>"],
    padding='post',
    maxlen=max_len
)
test_data = keras.preprocessing.sequence.pad_sequences(
    test_texts,
    value=word_index["<PAD>"],
    padding='post',
    maxlen=max_len
)
print(f"处理后第一条评论长度: {len(train_data[0])}")
print(f"填充后样本:\n{train_data[0]}")

模型构建

定义网络结构

vocab_size = 10000
model = keras.Sequential([
    keras.layers.Embedding(vocab_size, 16),           # 词嵌入层
    keras.layers.GlobalAveragePooling1D(),             # 全局平均池化
    keras.layers.Dense(16, activation='relu'),         # 全连接隐含层
    keras.layers.Dense(1, activation='sigmoid')        # 输出层,二分类
])
model.summary()

编译模型

model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)

训练验证分离

# 从训练集中切出10000条作为验证集
val_features = train_data[:10000]
partial_train_features = train_data[10000:]
val_labels = train_labels[:10000]
partial_train_labels = train_labels[10000:]

模型训练

history = model.fit(
    partial_train_features,
    partial_train_labels,
    epochs=40,
    batch_size=512,
    validation_data=(val_features, val_labels),
    verbose=2
)

模型评估

test_loss, test_acc = model.evaluate(test_data, test_labels)
print(f"测试集损失: {test_loss:.4f}, 准确率: {test_acc:.4f}")

训练过程可视化

通过图形展示训练过程中的损失和准确率变化,便于诊断过拟合。

history_dict = history.history
train_loss = history_dict['loss']
val_loss = history_dict['val_loss']
train_acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']

epochs_range = range(1, len(train_acc) + 1)

plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_loss, 'bo', label='训练损失')
plt.plot(epochs_range, val_loss, 'b-', label='验证损失')
plt.title('损失变化曲线')
plt.xlabel('训练轮数')
plt.ylabel('损失值')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_acc, 'ro', label='训练准确率')
plt.plot(epochs_range, val_acc, 'r-', label='验证准确率')
plt.title('准确率变化曲线')
plt.xlabel('训练轮数')
plt.ylabel('准确率')
plt.legend()

plt.tight_layout()
plt.savefig("./imdb_training_curves.png", dpi=200)
plt.show()

常见问题处理

网络下载失败

问题:执行load_data()get_word_index()时出现下载超时。

解决方案:

  • 手动下载所需文件(imdb.npz和imdb_word_index.json)
  • 通过path参数指定本地文件路径
  • 或将文件放置于~/.keras/datasets/目录下

相关文章

Trojan服务器搭建与配置

一、整体架构(先对齐认知)Clash Meta (PC / iOS / Android)        ↓ TLS   Trojan Server (443)        ↓     InternetTrojan 的核心是: TLS + HTTPS 流量伪装 看起来像正常网站 非常适合...

Tailscale 的详细用法

Tailscale 是一种基于 WireGuard 协议 的 零配置 VPN(虚拟私有网络)服务,让设备之间能够 安全、加密地直接连接,就像它们在同一个本地网络一样。它的核心特点是 简单、安全、跨平台。Tailscale 非常适合 没有公网 IP、两台电脑不在同一局域网 的场景。 简单来说,Tailscale 是什么?Tailscale 是一款让你的各种设备(电脑、服务器、手机...

Clash Tun 模式 导致 爱快(iKuai SD-Wan)内网域名无法访问

一、Clash  DNS 配置dns:  enable: true  listen: 0.0.0.0:53  ipv6: true  enhanced-mode: redir-host  nameserver:    - 223.5.5.5    - 223.6.6.6iKuai 内网域名 ...

深入解析Node.js运行环境与异步I/O架构

深入解析Node.js运行环境与异步I/O架构

核心定义与价值Node.js本质上是一个JavaScript运行环境,而非编程语言或应用框架。它赋予了JavaScript脱离浏览器在服务端、命令行工具及网络应用中执行的能力。其核心意义在于:用单一语言打通前后端开发壁垒。基于事件驱动与非阻塞I/O的架构特性,Node.js在处理API网关、实时通信及微服务等I/O密集型场景时表现卓越,已成为现代后端工程的主流选择。浏览器沙箱限制1995年Java...

ADO.NET SQL参数化查询的最佳实践

在 ADO.NET 中执行 SQL 查询时,参数化查询是一种关键的安全措施和性能优化手段。它通过将 SQL 命令和用户提供的数据分开处理,有效防止了 SQL 注入攻击,并有助于数据库缓存执行计划。下面总结了几种常用的参数化查询方式。 1. 使用 SqlParameter 对象(推荐) 这是最推荐的参数化查询方式。通过显式创建 SqlParameter 对象,您可以精确控制参数的类...

基于ELK的日志集中化分析系统搭建

构建统一日志管理平台的必要性 在分布式架构中,各服务节点独立运行,日志分散存储于不同主机。传统通过命令行工具如grep、awk逐个检索日志的方式,在数据量庞大时效率极低,难以实现快速定位问题。为提升运维效率,需建立集中式日志处理体系,具备日志采集、传输、存储、分析与告警能力。 ELK技术栈核心组件解析 Elasticsearch:分布式搜索引擎,支持全文检索、实时数据分析和高可用集群部署,...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。