使用BERT进行文本多分类的实现
本教程将演示如何使用BERT模型构建一个多分类器,以对文本数据进行分类。
数据集示例
我们将使用一个包含文本和对应标签的CSV文件。以下是一个示例数据片段:
label,text
2,真的很开心啊!!!!!!!
7,咳咳。。。。
1,//@陈宝存:回复@梅海东messi:这种非守法的公民存在,足见我们法制建设的艰难,你懂法吗?
7,也许有一天,你突然醒来,发现自己还在十几岁的年纪,年少时喜欢的男生就在你面前,看着你温暖地笑,说你是个傻瓜,告诉你你从没有失恋过,告诉你是他的唯一,告诉你后来的一切都没发生过,告诉你他其实一直爱你。
7,独栋别墅跟农家乐的区别就是:内在装修好,各种设施齐备,别的嘛——完全没区别!
7,置身其中,一周的烦躁无影踪。
7,高考都考这么多年了,就不应该搞个周年店庆么,考400送350,一本分数线7折,考三本送二本 体验券![偷笑]
7,到纳斯达克上市,我感觉自己更像一个旅行者,走到了这一站但这个地方不是我的家,我只是到这里来,证明自己做了一些想做的事情
7,六一儿童节到了。
7,不要这样,也不要总觉得自己总缺什么,只要快乐,你就什么都不缺。
7,问:我是已婚mm 在沪有套小房,想换大房把小房送父母,过户费太高,问怎么减免?
7,然后,前几天有人在医院看到她去做产检,她要生孩子了,是二胎。
7,我是被后面的人和前面的保安和人各种推挤压挡遮......
7,天呢!
7,昨天上午老程踩了一窖,为了剧情需要,晚上胡可又光脚踩了半天。
7,在我女儿七岁生日时,我要送给他一本日记本,有三张扉页,每张上我都会认真写上一行字。
7,男生之间没有耍心机他们不爽都直接开口大不了打一架 他们不需要你有多贴心只需要在特定的是给他一个手势拍拍他的肩膀 他们不开心也会哭会发泄会去打球会喝酒 其实和男生交朋友谈心更是另一种收获 该珍惜!
1,因为尼玛这本书讲述的是从耗子变成客运总裁!!!!!!!!
3,东海啊[泪][泪]
数据加载模块
我们定义一个自定义数据集类,用于加载和预处理CSV文件。
from torch.utils.data import Dataset
from datasets import load_dataset
class TextClassificationDataset(Dataset):
def __init__(self, split_name):
"""
初始化数据集。
Args:
split_name (str): 数据集的分割名称 (e.g., 'train', 'validation', 'test')。
"""
# 从指定路径加载CSV格式的数据集
self.data_split = load_dataset(
path="csv",
data_files=f"data/Weibo/{split_name}.csv",
split="train"
)
def __len__(self):
"""返回数据集的大小。"""
return len(self.data_split)
def __getitem__(self, index):
"""
获取指定索引的数据项。
Args:
index (int): 数据项的索引。
Returns:
tuple: 包含文本和标签的元组。
"""
text_content = self.data_split[index]["text"]
classification_label = self.data_split[index]["label"]
return text_content, classification_label
if __name__ == '__main__':
# 示例:加载测试集并打印前几项
sample_dataset = TextClassificationDataset("test")
for i in range(min(3, len(sample_dataset))):
print(sample_dataset[i])
模型定义
该模型利用预训练的BERT作为特征提取器,并在其之上添加一个全连接层用于分类。
from transformers import BertModel
import torch
# 确定设备 (GPU优先,否则使用CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载预训练的BERT模型 (中文基础版)
# 请确保指定的模型路径是正确的
bert_base_model = BertModel.from_pretrained(
r"D:\PycharmProjects\demo_15_01\model\bert-base-chinese\models--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"
).to(device)
# 定义下游任务模型:一个用于文本分类的模型
class TextClassifier(torch.nn.Module):
def __init__(self, num_classes=8):
"""
初始化分类器。
Args:
num_classes (int): 分类的类别数量。
"""
super().__init__()
# 添加一个全连接层,将BERT的输出(768维)映射到类别数量
self.classifier_layer = torch.nn.Linear(768, num_classes)
def forward(self, input_ids, attention_mask, token_type_ids):
"""
前向传播。
Args:
input_ids (torch.Tensor): 输入的token ID序列。
attention_mask (torch.Tensor): 注意力掩码。
token_type_ids (torch.Tensor): token类型ID。
Returns:
torch.Tensor: 分类概率。
"""
# 在计算梯度时冻结BERT模型的参数,以防止其被更新
with torch.no_grad():
bert_output = bert_base_model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
# 取[CLS] token的隐藏状态作为句子的表示,并传入分类层
# last_hidden_state[:, 0] 对应的是 [CLS] token的输出
sentence_representation = bert_output.last_hidden_state[:, 0]
logits = self.classifier_layer(sentence_representation)
# 使用softmax将输出转换为概率
probabilities = torch.softmax(logits, dim=1)
return probabilities
if __name__ == '__main__':
# 示例:打印模型结构
classifier = TextClassifier()
print(classifier)
训练流程
此部分代码包含了模型的训练和验证过程,包括数据加载、分词、模型前向传播、损失计算、反向传播及优化。
import torch
from torch.utils.data import DataLoader
from transformers import AdamW, BertTokenizer
# 假设 MyDataset 和 Model 类已在 MyData.py 和 net.py 文件中定义
# from MyData import MyDataset
# from net import Model
# 如果在同一文件中,则直接使用
# from __main__ import TextClassificationDataset, TextClassifier # 假设类名已修改
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 30000 # 训练轮数
# 加载BERT分词器
# 请确保模型路径是正确的
tokenizer = BertTokenizer.from_pretrained(
r"D:\PycharmProjects\demo_15_01\model\bert-base-chinese\models--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"
)
def create_data_collator(batch_data):
"""
为DataLoader创建数据整理函数,将文本批量编码并转换为Tensor。
Args:
batch_data (list): DataLoader提供的批次数据,包含(text, label)元组。
Returns:
tuple: 包含处理后的input_ids, attention_mask, token_type_ids, 和 labels。
"""
texts = [item[0] for item in batch_data]
labels = [item[1] for item in batch_data]
# 使用tokenizer对文本进行批量编码
encoded_inputs = tokenizer.batch_encode_plus(
batch_text_or_text_pairs=texts,
padding="max_length", # 填充到最大长度
truncation=True, # 截断超过最大长度的文本
max_length=512, # 最大序列长度
return_tensors="pt", # 返回PyTorch张量
return_length=True # 返回序列实际长度(可选)
)
input_ids = encoded_inputs["input_ids"]
attention_mask = encoded_inputs["attention_mask"]
token_type_ids = encoded_inputs["token_type_ids"]
# 转换为PyTorch LongTensor
target_labels = torch.LongTensor(labels)
return input_ids, attention_mask, token_type_ids, target_labels
# 创建训练集和验证集实例
train_dataset = TextClassificationDataset("train")
validation_dataset = TextClassificationDataset("validation")
# 创建训练DataLoader和验证DataLoader
train_loader = DataLoader(
dataset=train_dataset,
batch_size=100,
shuffle=True, # 训练时打乱数据
drop_last=True, # 丢弃最后一个不完整的batch
collate_fn=create_data_collator
)
validation_loader = DataLoader(
dataset=validation_dataset,
batch_size=50,
shuffle=False, # 验证时不需要打乱
drop_last=True,
collate_fn=create_data_collator
)
if __name__ == '__main__':
print(f"Using device: {DEVICE}")
# 初始化模型、优化器和损失函数
model = TextClassifier().to(DEVICE)
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()
# 开始训练循环
for epoch in range(NUM_EPOCHS):
model.train() # 设置模型为训练模式
total_train_loss = 0
correct_predictions = 0
total_samples = 0
# 遍历训练DataLoader
for batch_idx, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_loader):
# 将数据移至指定设备
input_ids, attention_mask, token_type_ids, labels = (
input_ids.to(DEVICE), attention_mask.to(DEVICE),
token_type_ids.to(DEVICE), labels.to(DEVICE)
)
# 前向传播
outputs = model(input_ids, attention_mask, token_type_ids)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 记录训练损失和准确率
total_train_loss += loss.item()
predicted_labels = torch.argmax(outputs, dim=1)
correct_predictions += (predicted_labels == labels).sum().item()
total_samples += len(labels)
if batch_idx % 5 == 0:
accuracy = correct_predictions / total_samples
print(f"Epoch: {epoch+1}/{NUM_EPOCHS}, Batch: {batch_idx+1}/{len(train_loader)}, "
f"Loss: {loss.item():.4f}, Accuracy: {accuracy:.4f}")
# 验证阶段
model.eval() # 设置模型为评估模式
total_val_loss = 0
total_val_correct = 0
total_val_samples = 0
with torch.no_grad():
for batch_idx, (input_ids, attention_mask, token_type_ids, labels) in enumerate(validation_loader):
input_ids, attention_mask, token_type_ids, labels = (
input_ids.to(DEVICE), attention_mask.to(DEVICE),
token_type_ids.to(DEVICE), labels.to(DEVICE)
)
val_outputs = model(input_ids, attention_mask, token_type_ids)
val_loss = criterion(val_outputs, labels)
total_val_loss += val_loss.item()
val_predicted = torch.argmax(val_outputs, dim=1)
total_val_correct += (val_predicted == labels).sum().item()
total_val_samples += len(labels)
avg_val_loss = total_val_loss / len(validation_loader)
avg_val_acc = total_val_correct / total_val_samples
print(f"Validation Result - Epoch: {epoch+1}, Average Loss: {avg_val_loss:.4f}, Average Accuracy: {avg_val_acc:.4f}")
# 保存模型参数
# 可以根据验证集上的表现来决定是否保存和更新最佳模型
model_save_path = f"params/{epoch+1}_bert_classifier.pth"
torch.save(model.state_dict(), model_save_path)
print(f"Epoch {epoch+1}: Model parameters saved to {model_save_path}")
模型测试
本节展示如何加载训练好的模型并对用户输入的文本进行实时分类预测。
import torch
from transformers import BertTokenizer
# 假设 Model 类已在 net.py 文件中定义
# from net import TextClassifier # 假设类名已修改
# 如果在同一文件中,则直接使用
# from __main__ import TextClassificationDataset, TextClassifier # 假设类名已修改
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义类别名称,需要与训练时的类别数和顺序对应
CLASS_NAMES = [
"like", "disgust", "happiness", "sadness",
"anger", "surprise", "fear", "none"
]
print(f"Using device: {DEVICE}")
# 初始化模型并加载预训练权重
model = TextClassifier(num_classes=len(CLASS_NAMES)).to(DEVICE)
# 加载BERT分词器
# 请确保模型路径是正确的
tokenizer = BertTokenizer.from_pretrained(
r"D:\PycharmProjects\demo_15_01\model\bert-base-chinese\models--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"
)
def preprocess_input_text(text_input):
"""
对单个输入的文本进行编码处理。
Args:
text_input (str): 用户输入的文本。
Returns:
tuple: 包含处理后的input_ids, attention_mask, token_type_ids。
"""
encoded_data = tokenizer.encode_plus(
text=text_input,
add_special_tokens=True, # 添加[CLS]和[SEP]
max_length=500, # 最大序列长度
padding="max_length", # 填充到最大长度
truncation=True, # 截断
return_tensors="pt" # 返回PyTorch张量
)
input_ids = encoded_data["input_ids"]
attention_mask = encoded_data["attention_mask"]
token_type_ids = encoded_data["token_type_ids"]
return input_ids, attention_mask, token_type_ids
def perform_prediction():
"""
加载模型并进行交互式预测。
"""
# 加载训练好的模型参数
try:
model.load_state_dict(torch.load("params/2_bert_classifier.pth", map_location=DEVICE)) # 示例加载第2个epoch的模型
model.eval() # 设置为评估模式
print("Model loaded successfully.")
except FileNotFoundError:
print("Error: Model parameters file not found. Please ensure the file exists.")
return
except Exception as e:
print(f"Error loading model parameters: {e}")
return
print("\nEnter text for classification. Type 'quit' to exit.")
while True:
user_input = input("Enter text: ")
if user_input.lower() == "quit":
print("Exiting prediction.")
break
# 预处理用户输入
input_ids, attention_mask, token_type_ids = preprocess_input_text(user_input)
# 将数据移至设备
input_ids, attention_mask, token_type_ids = (
input_ids.to(DEVICE), attention_mask.to(DEVICE), token_type_ids.to(DEVICE)
)
# 进行预测
with torch.no_grad():
predictions = model(input_ids, attention_mask, token_type_ids)
# 获取最可能的类别索引
predicted_class_index = torch.argmax(predictions, dim=1).item()
# 输出预测结果
predicted_class_name = CLASS_NAMES[predicted_class_index]
print(f"Model Prediction: {predicted_class_name}\n")
if __name__ == '__main__':
perform_prediction()