深度学习模型维度不匹配问题排查
在使用大语言模型进行微调时,你可能会遇到这样的错误提示:RuntimeError: expand(torch.cuda.FloatTensor{[2, 2, 1, 1, 512]}, size=[2, 1, 1, 512]): the number of sizes provided (4) must be greater or equal to the number of dimensions in the tensor (5)
这个错误表明模型在前向传播过程中检测到了张量维度的不一致。具体来看,问题出现在调用模型主体时的维度处理上,通常是由于输入数据格式与模型期望的格式存在差异导致的。
常见原因分析
- 输入张量维度异常:
- 核验输入数据(
input_ids、attention_mask、labels)的维度是否符合模型要求 - 检查数据预处理函数返回的结果是否满足预期的张量形状
- 标签张量多维问题:
- 确认标签张量的维度与输入标识符一致,避免出现多余的维度
- 批处理组装错误:
- 验证数据整理器是否正确地将独立样本合并为批量数据
- 模型参数配置不一致:
- 确保模型配置中的相关参数与预处理阶段的设置保持统一
排查方法
- 输出张量形状信息:
- 在数据预处理函数中加入形状打印语句,监控数据维度变化
def prepare_training_data(sample):
# 验证必要字段是否完整且类型正确
required_fields = ['question_context', 'question_body', 'correct_response',
'wrong_response', 'knowledge_misconception']
if not all(field in sample and isinstance(sample[field], (str, list))
for field in required_fields):
raise ValueError("必要的字段缺失或类型不正确")
# 处理列表类型的字段,转换为字符串
def convert_to_string(value):
return ' '.join(value) if isinstance(value, list) else value
question_context = convert_to_string(sample['question_context'])
question_body = convert_to_string(sample['question_body'])
correct_response = convert_to_string(sample['correct_response'])
wrong_response = convert_to_string(sample['wrong_response'])
knowledge_misconception = convert_to_string(sample['knowledge_misconception'])
# 构建输入文本和目标文本
input_sequence = (question_context + "\n" + question_body + "\n" +
correct_response + "\n" + wrong_response)
target_sequence = knowledge_misconception
# 对文本进行编码处理
encoded_inputs = tokenizer(input_sequence, max_length=512,
truncation=True, padding="max_length",
return_tensors="pt")
encoded_targets = tokenizer(target_sequence, max_length=512,
truncation=True, padding="max_length",
return_tensors="pt").input_ids
# 将填充部分的标签设置为-100,使其在损失计算时被忽略
encoded_targets[encoded_targets == tokenizer.pad_token_id] = -100
# 输出形状信息用于调试
print(f"编码输入形状: {encoded_inputs['input_ids'].shape}")
print(f"目标标签形状: {encoded_targets.shape}")
# 将标签添加到输入字典中,并移除多余的维度
encoded_inputs["labels"] = encoded_targets.squeeze()
return encoded_inputs
- 检查数据整理器的输出:
- 在数据整理器中添加调试信息,观察批次数据的实际维度
from transformers import DataCollatorForSeq2Seq
class DebugDataCollator(DataCollatorForSeq2Seq):
def __call__(self, features):
batch = super().__call__(features)
# 打印批次维度信息
print(f"批次 input_ids 形状: {batch['input_ids'].shape}")
print(f"批次 attention_mask 形状: {batch['attention_mask'].shape}")
print(f"批次 labels 形状: {batch['labels'].shape}")
return batch
通过以上调试步骤,你可以定位到具体是哪个环节出现了维度不匹配的问题。常见的解决方案包括使用squeeze()方法移除多余维度,或者调整数据预处理流程确保张量形状正确。