当前位置:首页 > 技术 > 正文内容

PyTorch优化器详解:从参数更新到学习率调度

访客 技术 2026年6月22日 10

优化器在深度学习中的作用

在神经网络训练流程中,前向传播计算输出、反向传播生成梯度之后,最关键的步骤是利用优化器对模型参数进行更新。优化器的作用类似于导航系统,依据梯度信息决定权重调整的方向与幅度,从而最小化损失函数。

优化器的基本操作流程

每个训练迭代周期内,优化器需执行三个核心操作:

  1. 梯度归零(zero_grad):清除上一轮迭代累积的梯度值,防止梯度叠加导致数值异常。
  2. 反向传播(backward):基于当前批次数据的损失值,自动计算各可训练参数的梯度。
  3. 参数更新(step):按照所选优化算法(如SGD或Adam),使用梯度更新网络权重。

完整训练示例:CIFAR-10图像分类任务

以下代码展示了如何结合 DataLoader、卷积神经网络和 SGD 优化器实现一个标准训练循环:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader

# 数据加载
transform = torchvision.transforms.ToTensor()
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# 定义模型结构
class CNNClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)

model = CNNClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练主循环
for epoch in range(20):
    total_loss = 0.0
    for images, labels in loader:
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播与优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f'Epoch [{epoch+1}/20], Loss: {total_loss:.4f}')

动态调整学习率:引入调度器机制

固定学习率可能无法满足整个训练过程的需求。通常希望初期快速收敛,后期精细微调。为此,PyTorch 提供了学习率调度器(LR Scheduler)来动态调节学习率。

例如,使用 StepLR 每隔若干轮将学习率乘以衰减因子:

# 接续上述代码
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

for epoch in range(20):
    for images, labels in loader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    scheduler.step()  # 每轮结束后更新学习率
    print(f'Current LR: {scheduler.get_last_lr()[0]:.6f}')

该策略可在训练前期保持较高学习率以加快收敛,在后期逐步降低学习率以提升模型稳定性。

训练流程全景图

至此,一个完整的深度学习训练闭环已经建立:

  • 数据准备:通过 Dataset 和 DataLoader 管理输入样本
  • 模型构建:使用 nn.Module 组织网络层
  • 误差评估:选择合适的损失函数衡量预测精度
  • 梯度计算:调用 backward 自动求导
  • 参数优化:借助 Optimizer 更新权重
  • 学习率调控:配合 Scheduler 实现动态调整

实践建议

学习率(lr)是影响训练效果的关键超参数。过大会导致震荡不收敛,过小则收敛缓慢。虽然 SGD 是基础优化方法,但因其对学习率敏感,实际项目中更常采用 Adam 等自适应优化器,它们能根据参数历史梯度自动调整步长,具有更强的鲁棒性和易用性。

相关文章

Linux crontab 详解

1) crontab 是什么cron 是 Linux 的定时任务守护进程;crontab 是用来编辑/查看“按时间周期执行命令”的表(cron table)。常见两类:用户 crontab:每个用户一份(crontab -e 编辑)系统级 crontab / cron.d:可指定执行用户(/etc/crontab、/etc/cron.d/*)2) crontab 时间...

富文本里可以允许的 HTML 属性

一、所有标签默认允许的安全属性(极少)class        (可选)id           (通常建议禁用)title️ 注意:id 容易被滥用做锚点注入,很多系统直接禁用class 允许的话最好只允许固定前缀(如 editor-*)二、a 标签允许属性<a href="" t...

Mac 安装 Node.js 指南

方法一:通过官网安装包(最简单,适合初学者)如果你只是想快速安装并开始使用,这是最直接的方法。访问 Node.js 官网。页面会显示两个版本:LTS (Recommended For Most Users):长期支持版,最稳定。建议选这个。Current:最新特性版,包含最新功能但可能不够稳定。下载 .pkg 安装包并运行。按照安装向导点击“下一步”即可完成。方法二:使用 Homebrew 安装(...

Dom\HTML_NO_DEFAULT_NS 的副作用:自动加闭合标签

在使用Dom\HTMLDocument时,Dom\HTML_NO_DEFAULT_NS 将禁止在解析过程中设置元素的命名空间, 此设置是为了与DOMDocument向后兼容而存在的。当使用它时,已知的一个副作用就是:自动加闭合标签例如 </img> 为什么会这样?当你使用:Dom\HTML_NO_DEFAULT_NS文档会变成 无命名空间模式,此时内部更接近 XML...

Laravel 事件和监听器创建

在 Laravel 中,使用 Artisan 命令创建 Events(事件) 和 Listeners(监听器) 是非常高效的。你可以通过以下几种方式来实现:1. 手动创建单个 Event如果你只想创建一个事件类,可以使用 make:event 命令:Bashphp artisan make:event UserRegistered执行后,文件将生成在 app/Even...

自定义域名解析神器 dnsmasq

什么是 dnsmasq?dnsmasq 是一个轻量级、功能强大的网络服务工具,专为小型和中等规模网络设计。它是一个综合的网络基础设施解决方案[1]。dnsmasq 能做什么?功能说明应用场景DNS 转发与缓存将 DNS 查询转发到上游服务器(ISP、Google DNS 等),并在本地缓存结果加快 DNS 查询速度,减少外部 DNS 流量本地 DNS解析本地网络设备的主机名,无需编辑&n...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。