当前位置:首页 > 技术 > 正文内容

PyTorch自动微分机制解析

访客 技术 2026年6月3日 1

自动微分核心组件

PyTorch的自动微分系统基于Tensor和Function两类对象构建有向无环图(DAG),采用动态图机制在每次前向传播时重建计算图。

Tensor关键属性

  • requires_grad:控制梯度计算开关
    • 新建Tensor时指定,默认False
    • 设为True时依赖节点自动继承
    • 使用detach()torch.no_grad()可禁用梯度跟踪
  • grad_fn:记录创建该Tensor的运算函数,叶子节点为None
  • grad:存储梯度值的缓存,非叶子节点在反向传播后自动释放

反向传播函数

backward()接收与调用Tensor同维度的参数(标量可省略),计算梯度并累加至grad属性。设置retain_graph=True可保留计算图进行多次反向传播。

标量反向传播示例

import torch

# 初始化输入和参数
input_val = torch.tensor([3.0])
weight = torch.randn(1, requires_grad=True)
bias = torch.randn(1, requires_grad=True)

# 前向计算
intermediate = weight * input_val
output = intermediate + bias

# 属性验证
print(f"叶子节点grad状态: {input_val.requires_grad}, {weight.requires_grad}, {bias.requires_grad}")
print(f"运算节点grad状态: {intermediate.requires_grad}, {output.requires_grad}")
print(f"梯度函数: {intermediate.grad_fn}, {output.grad_fn}")

# 反向传播
output.backward()
print(f"参数梯度: {weight.grad}, {bias.grad}")

非标量反向传播

处理非标量输出时需提供gradient参数进行向量-雅可比乘积:

x = torch.tensor([[1.0, 2.0]], requires_grad=True)
J_matrix = torch.zeros(2, 2)

# 定义多元函数
y = torch.empty(1, 2)
y[0, 0] = x[0, 0]**2 + 4*x[0, 1]
y[0, 1] = x[0, 1]**3 + 2*x[0, 0]

# 计算雅可比矩阵
y.backward(torch.tensor([[1.0, 0.0]]), retain_graph=True)
J_matrix[0] = x.grad
x.grad.zero_()

y.backward(torch.tensor([[0.0, 1.0]]))
J_matrix[1] = x.grad
print(J_matrix)

回归任务实战

import torch
import matplotlib.pyplot as plt

# 生成数据集
torch.manual_seed(42)
features = torch.unsqueeze(torch.linspace(-2, 2, 200), 1)
target = 2.5 * features**2 + 1.8 + 0.2*torch.randn(features.size())

# 初始化模型参数
coefficient = torch.randn(1, 1, requires_grad=True)
intercept = torch.randn(1, 1, requires_grad=True)

# 训练配置
learning_rate = 0.005
epochs = 1000

# 训练循环
for _ in range(epochs):
    predictions = features**2 @ coefficient + intercept
    loss = (0.5 * (predictions - target)**2).sum()
    
    loss.backward()
    with torch.no_grad():
        coefficient -= learning_rate * coefficient.grad
        intercept -= learning_rate * intercept.grad
        coefficient.grad.zero_()
        intercept.grad.zero_()

# 结果可视化
plt.scatter(features.numpy(), target.numpy())
plt.plot(features.numpy(), predictions.detach().numpy(), 'r-')
plt.show()

相关文章

Linux crontab 详解

1) crontab 是什么cron 是 Linux 的定时任务守护进程;crontab 是用来编辑/查看“按时间周期执行命令”的表(cron table)。常见两类:用户 crontab:每个用户一份(crontab -e 编辑)系统级 crontab / cron.d:可指定执行用户(/etc/crontab、/etc/cron.d/*)2) crontab 时间...

富文本里可以允许的 HTML 属性

一、所有标签默认允许的安全属性(极少)class        (可选)id           (通常建议禁用)title️ 注意:id 容易被滥用做锚点注入,很多系统直接禁用class 允许的话最好只允许固定前缀(如 editor-*)二、a 标签允许属性<a href="" t...

Mac 安装 Node.js 指南

方法一:通过官网安装包(最简单,适合初学者)如果你只是想快速安装并开始使用,这是最直接的方法。访问 Node.js 官网。页面会显示两个版本:LTS (Recommended For Most Users):长期支持版,最稳定。建议选这个。Current:最新特性版,包含最新功能但可能不够稳定。下载 .pkg 安装包并运行。按照安装向导点击“下一步”即可完成。方法二:使用 Homebrew 安装(...

Dom\HTML_NO_DEFAULT_NS 的副作用:自动加闭合标签

在使用Dom\HTMLDocument时,Dom\HTML_NO_DEFAULT_NS 将禁止在解析过程中设置元素的命名空间, 此设置是为了与DOMDocument向后兼容而存在的。当使用它时,已知的一个副作用就是:自动加闭合标签例如 </img> 为什么会这样?当你使用:Dom\HTML_NO_DEFAULT_NS文档会变成 无命名空间模式,此时内部更接近 XML...

Laravel 事件和监听器创建

在 Laravel 中,使用 Artisan 命令创建 Events(事件) 和 Listeners(监听器) 是非常高效的。你可以通过以下几种方式来实现:1. 手动创建单个 Event如果你只想创建一个事件类,可以使用 make:event 命令:Bashphp artisan make:event UserRegistered执行后,文件将生成在 app/Even...

自定义域名解析神器 dnsmasq

什么是 dnsmasq?dnsmasq 是一个轻量级、功能强大的网络服务工具,专为小型和中等规模网络设计。它是一个综合的网络基础设施解决方案[1]。dnsmasq 能做什么?功能说明应用场景DNS 转发与缓存将 DNS 查询转发到上游服务器(ISP、Google DNS 等),并在本地缓存结果加快 DNS 查询速度,减少外部 DNS 流量本地 DNS解析本地网络设备的主机名,无需编辑&n...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。