融合GNN、LSTM与Transformer的时空预测模型设计与实现
模型架构设计
针对具有时空特性的序列数据(如交通流量、气象变化等),本文提出一种结合图神经网络(GNN)、长短期记忆网络(LSTM)和Transformer的混合建模方法。该结构通过分层协作机制,分别捕获空间依赖性、局部时序动态以及长期时间模式。
整体流程如下:首先利用GNN对节点间的拓扑关系进行建模,提取各时刻的空间特征;随后将这些特征序列输入至两个并行的时间编码模块——LSTM用于捕捉近期邻近时间步的变化趋势,而Transformer则借助自注意力机制挖掘跨时间段的全局关联;最终将两路输出拼接并通过全连接层生成预测结果。
核心组件解析
图神经网络(GNN)
GNN负责处理非欧几里得结构数据,在本模型中用于表达区域之间的空间影响。采用简化的图卷积层(GCN-style layer),其前向传播公式为:
H^{(l+1)} = \sigma(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)})
其中 Ã 是添加自环后的邻接矩阵,D̃ 为其对应的度矩阵,H 表示节点特征,W 为可学习参数,σ 为激活函数。
LSTM 模块
LSTM擅长处理短周期时间依赖问题,尤其适用于波动频繁但趋势短暂的数据。其门控机制包括遗忘门 f_t、输入门 i_t 和输出门 o_t,细胞状态更新方式如下:
f_t = sigmoid(W_f @ [h_{t-1}, x_t] + b_f)
i_t = sigmoid(W_i @ [h_{t-1}, x_t] + b_i)
C_tilde = tanh(W_C @ [h_{t-1}, x_t] + b_C)
C_t = f_t * C_{t-1} + i_t * C_tilde
o_t = sigmoid(W_o @ [h_{t-1}, x_t] + b_o)
h_t = o_t * tanh(C_t)
Transformer 编码器
为捕捉长时间跨度下的潜在规律(例如每日通勤高峰),引入单层Transformer编码器。关键在于缩放点积注意力(Scaled Dot-Product Attention):
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
查询(Q)、键(K)、值(V)均由输入线性变换得到,dk 为键向量维度,用于稳定梯度。
融合策略与训练方法
在每个样本上,GNN先对每一时间步的原始观测执行图卷积操作,形成时空特征张量。然后将其重塑为标准序列格式,并送入LSTM和Transformer分支。
设LSTM最后一时间步的隐藏状态为 hlstm,Transformer最后位置的输出为 htrans,融合过程表示为:
y = W_c [h^{lstm}; h^{trans}] + b_c
损失函数选用均方误差(MSE):
\mathcal{L} = \frac{1}{N}\sum_{i=1}^N (y_i - \hat{y}_i)^2
使用Adam优化器进行端到端反向传播训练。
交通流预测实例实现
以下基于PyTorch构建一个轻量级实验模型,模拟五节点路网下的多变量时间序列预测任务。
import torch
import torch.nn as nn
class GCNLayer(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.proj = nn.Linear(in_dim, out_dim)
def forward(self, x, adj):
# x: (B, T, N, F), adj: (N, N)
B, T, N, _ = x.shape
x = x.view(B * T, N, -1)
z = self.proj(x)
z = torch.matmul(adj, z)
return torch.relu(z).view(B, T, N, -1)
class SpatioTemporalModel(nn.Module):
def __init__(self, nodes, feat_in, hidden, feat_out):
super().__init__()
self.nodes = nodes
self.gcn = GCNLayer(feat_in, hidden)
self.lstm = nn.LSTM(hidden * nodes, hidden, batch_first=True)
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden, nhead=2, dim_feedforward=64, batch_first=True
)
self.trans_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
self.predictor = nn.Linear(hidden * 2, feat_out)
self.reducer = nn.Linear(hidden * nodes, hidden)
def forward(self, seq, graph):
# 提取空间特征
g_feat = self.gcn(seq, graph) # [B, T, N, H]
B, T, _, H = g_feat.shape
flat = g_feat.view(B, T, -1) # [B, T, N*H]
reduced = self.reducer(flat) # [B, T, H]
# LSTM路径
lstm_out, (h_n, _) = self.lstm(reduced)
short_term = h_n[-1] # 最后层隐状态
# Transformer路径
trans_out = self.trans_encoder(reduced)
long_term = trans_out[:, -1, :] # 取末位表征
# 特征融合
fused = torch.cat([short_term, long_term], dim=-1)
return self.predictor(fused)
仿真训练流程
# 参数设置
N, F, H, O = 5, 1, 16, 1
T, B = 10, 8
# 构造归一化邻接矩阵
adj_mx = (torch.rand(N, N) > 0.7).float()
deg = adj_mx.sum(dim=1, keepdim=True) + 1e-5
adj_mx /= deg
# 随机生成数据
X = torch.randn(B, T, N, F)
Y = torch.randn(B, O)
# 初始化模型与优化器
model = SpatioTemporalModel(N, F, H, O)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 训练循环
for epoch in range(50):
optimizer.zero_grad()
pred = model(X, adj_mx)
loss = criterion(pred, Y)
loss.backward()
optimizer.step()
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
结果显示模型能够有效收敛,表明所设计结构具备基本的时空联合建模能力。
拓展方向
- 替换GCN为GAT(图注意力网络),使模型能自动加权邻居重要性。
- 增加时空残差连接与层级池化,提升深层网络稳定性。
- 接入真实数据集(如METR-LA或PEMSD4),并引入外部变量(天气、节假日)增强泛化性。