PyTorch自动微分机制解析
自动微分核心组件
PyTorch的自动微分系统基于Tensor和Function两类对象构建有向无环图(DAG),采用动态图机制在每次前向传播时重建计算图。
Tensor关键属性
- requires_grad:控制梯度计算开关
- 新建Tensor时指定,默认False
- 设为True后,由该Tensor参与运算得到的结果Tensor会自动继承requires_grad=True
- 使用detach()或torch.no_grad()可禁用梯度跟踪
- grad_fn:记录创建该Tensor的运算函数,叶子节点为None
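- grad:存储梯度值,每次反向传播都会累加而不是覆盖(因此训练中需手动清零);默认只有requires_grad=True的叶子节点会填充grad,非叶子节点的grad为None(需调用retain_grad()才会保留)。反向传播后自动释放的是计算图本身,而非grad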
反向传播函数
backward()的gradient参数必须与调用它的Tensor形状相同(输出只含单个元素时可省略),计算出的梯度会累加到叶子节点的grad属性中。计算图默认在反向传播后被释放,设置retain_graph=True可保留计算图以进行多次反向传播。
标量反向传播示例
import torch
# Leaf tensors: the input is plain data, while weight and bias request gradients.
input_val = torch.tensor([3.0])
weight = torch.randn(1, requires_grad=True)
bias = torch.randn(1, requires_grad=True)

# Forward pass written with the method form of the operators
# (intermediate = weight * input_val, output = intermediate + bias).
intermediate = weight.mul(input_val)
output = intermediate.add(bias)

# Inspect the autograd bookkeeping on leaves and on computed tensors.
print(f"叶子节点grad状态: {input_val.requires_grad}, {weight.requires_grad}, {bias.requires_grad}")
print(f"运算节点grad状态: {intermediate.requires_grad}, {output.requires_grad}")
print(f"梯度函数: {intermediate.grad_fn}, {output.grad_fn}")

# output holds a single element, so backward() needs no explicit gradient
# argument; d(output)/d(weight) = input_val and d(output)/d(bias) = 1.
output.backward()
print(f"参数梯度: {weight.grad}, {bias.grad}")
非标量反向传播
处理非标量输出时,需要向backward()传入与输出同形状的gradient参数v,PyTorch据此计算向量-雅可比积vᵀJ;依次传入单位向量e_i,得到的x.grad就是雅可比矩阵J的第i行:
# Build the Jacobian of a vector-valued function row by row, using one
# vector-Jacobian product (backward with a one-hot seed) per output element.
x = torch.tensor([[1.0, 2.0]], requires_grad=True)

# Multivariate function y = f(x), shape (1, 2):
#   y[0, 0] = x0**2 + 4*x1
#   y[0, 1] = x1**3 + 2*x0
# torch.stack builds y out-of-place, so autograd tracks each component directly
# instead of recording in-place writes into an uninitialized torch.empty buffer.
y = torch.stack(
    (x[0, 0] ** 2 + 4 * x[0, 1], x[0, 1] ** 3 + 2 * x[0, 0])
).unsqueeze(0)

# Compute the Jacobian: row i of J equals e_i^T J, obtained by seeding backward with e_i.
n_out = y.numel()
J_matrix = torch.zeros(n_out, x.numel())
for row, seed in enumerate(torch.eye(n_out)):
    if x.grad is not None:
        # .grad accumulates across backward() calls; clear the previous row.
        x.grad.zero_()
    # Keep the graph alive for every pass except the last one.
    y.backward(seed.view_as(y), retain_graph=row < n_out - 1)
    J_matrix[row] = x.grad.flatten()
print(J_matrix)
回归任务实战
import torch
import matplotlib.pyplot as plt
# Generate a synthetic dataset: target = 2.5 * x^2 + 1.8 plus Gaussian noise (std 0.2).
torch.manual_seed(42)
features = torch.unsqueeze(torch.linspace(-2, 2, 200), 1)  # shape (200, 1)
target = 2.5 * features**2 + 1.8 + 0.2 * torch.randn(features.size())

# Model parameters for: predictions = features**2 @ coefficient + intercept.
coefficient = torch.randn(1, 1, requires_grad=True)
intercept = torch.randn(1, 1, requires_grad=True)

# Training configuration.
# The loss below is a *sum* over all 200 samples, so its Hessian
# [[sum(x^4), sum(x^2)], [sum(x^2), N]] has a largest eigenvalue of roughly 780
# for this data. Plain gradient descent is only stable for
# learning_rate < 2 / lambda_max (about 0.0026). The previous value of 0.005
# made the parameters diverge to inf/NaN.
learning_rate = 0.001
epochs = 1000

# Training loop: forward pass, backward pass, manual gradient-descent update.
for _ in range(epochs):
    predictions = features**2 @ coefficient + intercept
    loss = (0.5 * (predictions - target) ** 2).sum()
    loss.backward()
    # Update the leaf parameters in place without recording the update in the graph.
    with torch.no_grad():
        coefficient -= learning_rate * coefficient.grad
        intercept -= learning_rate * intercept.grad
    # .grad accumulates across backward() calls, so reset it every step.
    coefficient.grad.zero_()
    intercept.grad.zero_()

# Predictions of the trained model. These are computed after the final update;
# the loop's last `predictions` predates that update.
with torch.no_grad():
    predictions = features**2 @ coefficient + intercept
# Visualize the fit: noisy samples as a scatter plot, model predictions as a red line.
x_axis = features.numpy()
plt.scatter(x_axis, target.numpy())
plt.plot(x_axis, predictions.detach().numpy(), 'r-')
plt.show()