Transformer输入部分详解与代码实现(基于PyTorch)
在Transformer模型中,输入部分负责将文本序列转换为包含语义和位置信息的向量。本文详细讲解了词嵌入(Embedding) 和 位置编码(Positional Encoding) 的原理,并提供了基于PyTorch的实现示例。
一、输入部分的核心任务
Transformer模型没有递归结构,无法自然捕捉序列顺序。因此,输入部分需要完成以下两个关键任务:
- 语义映射:将词索引转换为连续的低维向量。
- 位置注入:手动加入位置信息以区分不同顺序的句子。
二、词嵌入(Embedding):从离散到连续的映射
2.1 原理简介
词嵌入通过一个可学习的矩阵,将每个词索引映射为固定维度的向量。为了平衡与位置编码的量级,通常会将嵌入结果乘以√嵌入维度。
2.2 代码实现
import torch
import torch.nn as nn
import math
class WordEmbedding(nn.Module):
def __init__(self, vocab_size, embed_dim):
super().__init__()
self.embed_layer = nn.Embedding(vocab_size, embed_dim)
def forward(self, input_ids):
embeddings = self.embed_layer(input_ids) * math.sqrt(embed_dim)
return embeddings
三、位置编码(Positional Encoding):注入时序信息
3.1 原理简介
位置编码使用正弦和余弦函数生成,公式如下:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
其中,pos是词的位置索引,i是向量维度索引。
3.2 代码实现
class PositionEncoder(nn.Module):
def __init__(self, embed_dim, max_len, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(max_len, embed_dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
四、完整流程与示例
4.1 完整流程
- 输入为批量词索引。
- 通过
WordEmbedding类进行词嵌入并缩放。 - 通过
PositionEncoder类叠加位置信息,输出最终输入向量。
4.2 示例代码
def test_pipeline():
input_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
embedding_layer = WordEmbedding(vocab_size=1024, embed_dim=512)
pos_encoder = PositionEncoder(embed_dim=512, max_len=64)
embedded_input = embedding_layer(input_ids)
print(f"词嵌入后形状:{embedded_input.shape}")
encoded_input = pos_encoder(embedded_input)
print(f"位置编码后形状:{encoded_input.shape}")
if __name__ == '__main__':
test_pipeline()