当前位置:首页 > 技术 > 正文内容

使用BERT进行文本多分类的实现

访客 技术 2026年6月7日 1

本教程将演示如何使用BERT模型构建一个多分类器,以对文本数据进行分类。

数据集示例

我们将使用一个包含文本和对应标签的CSV文件。以下是一个示例数据片段:

label,text
2,真的很开心啊!!!!!!!
7,咳咳。。。。
1,//@陈宝存:回复@梅海东messi:这种非守法的公民存在,足见我们法制建设的艰难,你懂法吗?
7,也许有一天,你突然醒来,发现自己还在十几岁的年纪,年少时喜欢的男生就在你面前,看着你温暖地笑,说你是个傻瓜,告诉你你从没有失恋过,告诉你是他的唯一,告诉你后来的一切都没发生过,告诉你他其实一直爱你。
7,独栋别墅跟农家乐的区别就是:内在装修好,各种设施齐备,别的嘛——完全没区别!
7,置身其中,一周的烦躁无影踪。
7,高考都考这么多年了,就不应该搞个周年店庆么,考400送350,一本分数线7折,考三本送二本 体验券![偷笑]
7,到纳斯达克上市,我感觉自己更像一个旅行者,走到了这一站但这个地方不是我的家,我只是到这里来,证明自己做了一些想做的事情
7,六一儿童节到了。
7,不要这样,也不要总觉得自己总缺什么,只要快乐,你就什么都不缺。
7,问:我是已婚mm 在沪有套小房,想换大房把小房送父母,过户费太高,问怎么减免?
7,然后,前几天有人在医院看到她去做产检,她要生孩子了,是二胎。
7,我是被后面的人和前面的保安和人各种推挤压挡遮......
7,天呢!
7,昨天上午老程踩了一窖,为了剧情需要,晚上胡可又光脚踩了半天。
7,在我女儿七岁生日时,我要送给他一本日记本,有三张扉页,每张上我都会认真写上一行字。
7,男生之间没有耍心机他们不爽都直接开口大不了打一架 他们不需要你有多贴心只需要在特定的是给他一个手势拍拍他的肩膀 他们不开心也会哭会发泄会去打球会喝酒 其实和男生交朋友谈心更是另一种收获 该珍惜!
1,因为尼玛这本书讲述的是从耗子变成客运总裁!!!!!!!!
3,东海啊[泪][泪]

数据加载模块

我们定义一个自定义数据集类,用于加载和预处理CSV文件。


from torch.utils.data import Dataset
from datasets import load_dataset

class TextClassificationDataset(Dataset):
    def __init__(self, split_name):
        """
        初始化数据集。
        Args:
            split_name (str): 数据集的分割名称 (e.g., 'train', 'validation', 'test')。
        """
        # 从指定路径加载CSV格式的数据集
        self.data_split = load_dataset(
            path="csv",
            data_files=f"data/Weibo/{split_name}.csv",
            split="train"
        )

    def __len__(self):
        """返回数据集的大小。"""
        return len(self.data_split)

    def __getitem__(self, index):
        """
        获取指定索引的数据项。
        Args:
            index (int): 数据项的索引。
        Returns:
            tuple: 包含文本和标签的元组。
        """
        text_content = self.data_split[index]["text"]
        classification_label = self.data_split[index]["label"]
        return text_content, classification_label

if __name__ == '__main__':
    # 示例:加载测试集并打印前几项
    sample_dataset = TextClassificationDataset("test")
    for i in range(min(3, len(sample_dataset))):
        print(sample_dataset[i])

模型定义

该模型利用预训练的BERT作为特征提取器,并在其之上添加一个全连接层用于分类。


from transformers import BertModel
import torch

# 确定设备 (GPU优先,否则使用CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载预训练的BERT模型 (中文基础版)
# 请确保指定的模型路径是正确的
bert_base_model = BertModel.from_pretrained(
    r"D:\PycharmProjects\demo_15_01\model\bert-base-chinese\models--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"
).to(device)

# 定义下游任务模型:一个用于文本分类的模型
class TextClassifier(torch.nn.Module):
    def __init__(self, num_classes=8):
        """
        初始化分类器。
        Args:
            num_classes (int): 分类的类别数量。
        """
        super().__init__()
        # 添加一个全连接层,将BERT的输出(768维)映射到类别数量
        self.classifier_layer = torch.nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """
        前向传播。
        Args:
            input_ids (torch.Tensor): 输入的token ID序列。
            attention_mask (torch.Tensor): 注意力掩码。
            token_type_ids (torch.Tensor): token类型ID。
        Returns:
            torch.Tensor: 分类概率。
        """
        # 在计算梯度时冻结BERT模型的参数,以防止其被更新
        with torch.no_grad():
            bert_output = bert_base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )
        # 取[CLS] token的隐藏状态作为句子的表示,并传入分类层
        # last_hidden_state[:, 0] 对应的是 [CLS] token的输出
        sentence_representation = bert_output.last_hidden_state[:, 0]
        logits = self.classifier_layer(sentence_representation)
        # 使用softmax将输出转换为概率
        probabilities = torch.softmax(logits, dim=1)
        return probabilities

if __name__ == '__main__':
    # 示例:打印模型结构
    classifier = TextClassifier()
    print(classifier)

训练流程

此部分代码包含了模型的训练和验证过程,包括数据加载、分词、模型前向传播、损失计算、反向传播及优化。


import torch
from torch.utils.data import DataLoader
from transformers import AdamW, BertTokenizer
# 假设 MyDataset 和 Model 类已在 MyData.py 和 net.py 文件中定义
# from MyData import MyDataset
# from net import Model

# 如果在同一文件中,则直接使用
# from __main__ import TextClassificationDataset, TextClassifier # 假设类名已修改

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 30000 # 训练轮数

# 加载BERT分词器
# 请确保模型路径是正确的
tokenizer = BertTokenizer.from_pretrained(
    r"D:\PycharmProjects\demo_15_01\model\bert-base-chinese\models--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"
)

def create_data_collator(batch_data):
    """
    为DataLoader创建数据整理函数,将文本批量编码并转换为Tensor。
    Args:
        batch_data (list): DataLoader提供的批次数据,包含(text, label)元组。
    Returns:
        tuple: 包含处理后的input_ids, attention_mask, token_type_ids, 和 labels。
    """
    texts = [item[0] for item in batch_data]
    labels = [item[1] for item in batch_data]

    # 使用tokenizer对文本进行批量编码
    encoded_inputs = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=texts,
        padding="max_length",       # 填充到最大长度
        truncation=True,            # 截断超过最大长度的文本
        max_length=512,             # 最大序列长度
        return_tensors="pt",        # 返回PyTorch张量
        return_length=True          # 返回序列实际长度(可选)
    )

    input_ids = encoded_inputs["input_ids"]
    attention_mask = encoded_inputs["attention_mask"]
    token_type_ids = encoded_inputs["token_type_ids"]
    # 转换为PyTorch LongTensor
    target_labels = torch.LongTensor(labels)

    return input_ids, attention_mask, token_type_ids, target_labels

# 创建训练集和验证集实例
train_dataset = TextClassificationDataset("train")
validation_dataset = TextClassificationDataset("validation")

# 创建训练DataLoader和验证DataLoader
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=True,           # 训练时打乱数据
    drop_last=True,         # 丢弃最后一个不完整的batch
    collate_fn=create_data_collator
)
validation_loader = DataLoader(
    dataset=validation_dataset,
    batch_size=50,
    shuffle=False,          # 验证时不需要打乱
    drop_last=True,
    collate_fn=create_data_collator
)

if __name__ == '__main__':
    print(f"Using device: {DEVICE}")
    # 初始化模型、优化器和损失函数
    model = TextClassifier().to(DEVICE)
    optimizer = AdamW(model.parameters(), lr=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # 开始训练循环
    for epoch in range(NUM_EPOCHS):
        model.train()  # 设置模型为训练模式
        total_train_loss = 0
        correct_predictions = 0
        total_samples = 0

        # 遍历训练DataLoader
        for batch_idx, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_loader):
            # 将数据移至指定设备
            input_ids, attention_mask, token_type_ids, labels = (
                input_ids.to(DEVICE), attention_mask.to(DEVICE),
                token_type_ids.to(DEVICE), labels.to(DEVICE)
            )

            # 前向传播
            outputs = model(input_ids, attention_mask, token_type_ids)
            loss = criterion(outputs, labels)

            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 记录训练损失和准确率
            total_train_loss += loss.item()
            predicted_labels = torch.argmax(outputs, dim=1)
            correct_predictions += (predicted_labels == labels).sum().item()
            total_samples += len(labels)

            if batch_idx % 5 == 0:
                accuracy = correct_predictions / total_samples
                print(f"Epoch: {epoch+1}/{NUM_EPOCHS}, Batch: {batch_idx+1}/{len(train_loader)}, "
                      f"Loss: {loss.item():.4f}, Accuracy: {accuracy:.4f}")

        # 验证阶段
        model.eval()  # 设置模型为评估模式
        total_val_loss = 0
        total_val_correct = 0
        total_val_samples = 0

        with torch.no_grad():
            for batch_idx, (input_ids, attention_mask, token_type_ids, labels) in enumerate(validation_loader):
                input_ids, attention_mask, token_type_ids, labels = (
                    input_ids.to(DEVICE), attention_mask.to(DEVICE),
                    token_type_ids.to(DEVICE), labels.to(DEVICE)
                )

                val_outputs = model(input_ids, attention_mask, token_type_ids)
                val_loss = criterion(val_outputs, labels)

                total_val_loss += val_loss.item()
                val_predicted = torch.argmax(val_outputs, dim=1)
                total_val_correct += (val_predicted == labels).sum().item()
                total_val_samples += len(labels)

        avg_val_loss = total_val_loss / len(validation_loader)
        avg_val_acc = total_val_correct / total_val_samples
        print(f"Validation Result - Epoch: {epoch+1}, Average Loss: {avg_val_loss:.4f}, Average Accuracy: {avg_val_acc:.4f}")

        # 保存模型参数
        # 可以根据验证集上的表现来决定是否保存和更新最佳模型
        model_save_path = f"params/{epoch+1}_bert_classifier.pth"
        torch.save(model.state_dict(), model_save_path)
        print(f"Epoch {epoch+1}: Model parameters saved to {model_save_path}")

模型测试

本节展示如何加载训练好的模型并对用户输入的文本进行实时分类预测。


import torch
from transformers import BertTokenizer
# 假设 Model 类已在 net.py 文件中定义
# from net import TextClassifier # 假设类名已修改

# 如果在同一文件中,则直接使用
# from __main__ import TextClassificationDataset, TextClassifier # 假设类名已修改

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义类别名称,需要与训练时的类别数和顺序对应
CLASS_NAMES = [
    "like", "disgust", "happiness", "sadness",
    "anger", "surprise", "fear", "none"
]
print(f"Using device: {DEVICE}")

# 初始化模型并加载预训练权重
model = TextClassifier(num_classes=len(CLASS_NAMES)).to(DEVICE)

# 加载BERT分词器
# 请确保模型路径是正确的
tokenizer = BertTokenizer.from_pretrained(
    r"D:\PycharmProjects\demo_15_01\model\bert-base-chinese\models--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"
)

def preprocess_input_text(text_input):
    """
    对单个输入的文本进行编码处理。
    Args:
        text_input (str): 用户输入的文本。
    Returns:
        tuple: 包含处理后的input_ids, attention_mask, token_type_ids。
    """
    encoded_data = tokenizer.encode_plus(
        text=text_input,
        add_special_tokens=True,    # 添加[CLS]和[SEP]
        max_length=500,             # 最大序列长度
        padding="max_length",       # 填充到最大长度
        truncation=True,            # 截断
        return_tensors="pt"         # 返回PyTorch张量
    )
    input_ids = encoded_data["input_ids"]
    attention_mask = encoded_data["attention_mask"]
    token_type_ids = encoded_data["token_type_ids"]
    return input_ids, attention_mask, token_type_ids

def perform_prediction():
    """
    加载模型并进行交互式预测。
    """
    # 加载训练好的模型参数
    try:
        model.load_state_dict(torch.load("params/2_bert_classifier.pth", map_location=DEVICE)) # 示例加载第2个epoch的模型
        model.eval() # 设置为评估模式
        print("Model loaded successfully.")
    except FileNotFoundError:
        print("Error: Model parameters file not found. Please ensure the file exists.")
        return
    except Exception as e:
        print(f"Error loading model parameters: {e}")
        return

    print("\nEnter text for classification. Type 'quit' to exit.")
    while True:
        user_input = input("Enter text: ")
        if user_input.lower() == "quit":
            print("Exiting prediction.")
            break

        # 预处理用户输入
        input_ids, attention_mask, token_type_ids = preprocess_input_text(user_input)

        # 将数据移至设备
        input_ids, attention_mask, token_type_ids = (
            input_ids.to(DEVICE), attention_mask.to(DEVICE), token_type_ids.to(DEVICE)
        )

        # 进行预测
        with torch.no_grad():
            predictions = model(input_ids, attention_mask, token_type_ids)
            # 获取最可能的类别索引
            predicted_class_index = torch.argmax(predictions, dim=1).item()

        # 输出预测结果
        predicted_class_name = CLASS_NAMES[predicted_class_index]
        print(f"Model Prediction: {predicted_class_name}\n")

if __name__ == '__main__':
    perform_prediction()

相关文章

Linux crontab 详解

1) crontab 是什么cron 是 Linux 的定时任务守护进程;crontab 是用来编辑/查看“按时间周期执行命令”的表(cron table)。常见两类:用户 crontab:每个用户一份(crontab -e 编辑)系统级 crontab / cron.d:可指定执行用户(/etc/crontab、/etc/cron.d/*)2) crontab 时间...

富文本里可以允许的 HTML 属性

一、所有标签默认允许的安全属性(极少)class        (可选)id           (通常建议禁用)title️ 注意:id 容易被滥用做锚点注入,很多系统直接禁用class 允许的话最好只允许固定前缀(如 editor-*)二、a 标签允许属性<a href="" t...

Mac 安装 Node.js 指南

方法一:通过官网安装包(最简单,适合初学者)如果你只是想快速安装并开始使用,这是最直接的方法。访问 Node.js 官网。页面会显示两个版本:LTS (Recommended For Most Users):长期支持版,最稳定。建议选这个。Current:最新特性版,包含最新功能但可能不够稳定。下载 .pkg 安装包并运行。按照安装向导点击“下一步”即可完成。方法二:使用 Homebrew 安装(...

Dom\HTML_NO_DEFAULT_NS 的副作用:自动加闭合标签

在使用Dom\HTMLDocument时,Dom\HTML_NO_DEFAULT_NS 将禁止在解析过程中设置元素的命名空间, 此设置是为了与DOMDocument向后兼容而存在的。当使用它时,已知的一个副作用就是:自动加闭合标签例如 </img> 为什么会这样?当你使用:Dom\HTML_NO_DEFAULT_NS文档会变成 无命名空间模式,此时内部更接近 XML...

Laravel 事件和监听器创建

在 Laravel 中,使用 Artisan 命令创建 Events(事件) 和 Listeners(监听器) 是非常高效的。你可以通过以下几种方式来实现:1. 手动创建单个 Event如果你只想创建一个事件类,可以使用 make:event 命令:Bashphp artisan make:event UserRegistered执行后,文件将生成在 app/Even...

自定义域名解析神器 dnsmasq

什么是 dnsmasq?dnsmasq 是一个轻量级、功能强大的网络服务工具,专为小型和中等规模网络设计。它是一个综合的网络基础设施解决方案[1]。dnsmasq 能做什么?功能说明应用场景DNS 转发与缓存将 DNS 查询转发到上游服务器(ISP、Google DNS 等),并在本地缓存结果加快 DNS 查询速度,减少外部 DNS 流量本地 DNS解析本地网络设备的主机名,无需编辑&n...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。