当前位置:首页 > 技术 > 正文内容

卷积神经网络注意力热图可视化技术详解

访客 技术 2026年6月10日 1

一、注意力热图概述

注意力热图是一种直观展现神经网络在图像处理过程中关注区域的可视化技术。该技术能够帮助研究人员理解模型决策机制,从而为模型优化提供重要参考。在计算机视觉领域,注意力热图被广泛用于图像分类、目标检测、语义分割等任务。

生成注意力热图的核心思路是利用神经网络的输出特征图。通常做法是选取网络最后一个卷积层的特征图,将其 resize 到输入图像尺寸后,通过特定函数与原图叠加即可得到热力图。本次技术分享将详细介绍类激活映射(Class Activation Mapping,CAM)方法的原理与实现。

二、类激活映射(CAM)算法原理

CAM 方法源自论文《Learning Deep Features for Discriminative Localization》,其核心思想是通过卷积层特征图与全连接层权重的加权组合来生成注意力热图。具体实现步骤如下:

步骤一:提取特征图
从神经网络中获取最后一层卷积的输出特征图,维度为 [B, C, H, W],其中 B 表示 batch size,C 为通道数,H 和 W 分别为特征图的高和宽。当输入单张图像时,B=1。

步骤二:获取分类权重
提取训练完成模型的分类层权重矩阵,该权重的输入维度需与特征图通道数严格对应。

步骤三:加权融合
将各通道特征图与对应分类权重进行加权求和运算,最终得到各类别的注意力热图。

三、PyTorch 环境下 CAM 的代码实现

以下是完整的 CAM 实现代码,包含数据预处理、特征提取、热图生成与可视化等关键环节。

(一)模型加载与参数配置

import os
import numpy as np
import cv2
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from PIL import Image

# 导入自定义网络架构
from model import model

NUM_CLASSES = 5

# 加载预训练权重
pretrained_model = model(num_classes=NUM_CLASSES)
pretrained_model.load_state_dict(torch.load('weights.pth', map_location=lambda storage, loc: storage))

# 分离特征提取器与分类头
feature_extractor = nn.Sequential(*list(pretrained_model.children())[:-2])
fc_weights = pretrained_model.state_dict()['classifier.weight'].cpu().numpy()

# 类别标签映射
class_labels = {0: 'car', 1: 'bird', 2: 'tree', 3: 'sky', 4: 'person'}

pretrained_model.eval()
feature_extractor.eval()

(二)图像预处理与预测推理

# 定义数据转换流水线
data_transform = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    "val": transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# 加载测试图像
test_image_path = '/data/test.jpg'
_, image_name = os.path.split(test_image_path)
features_cache = []

input_image = Image.open(test_image_path).convert('RGB')
input_tensor = data_transform['val'](input_image).unsqueeze(0)  # [1,3,224,224]

# 提取卷积特征
conv_features = feature_extractor(input_tensor).detach().cpu().numpy()  # [1,960,7,7]

# 模型预测
logits = pretrained_model(input_tensor)  # [1,5]
probabilities = torch.nn.functional.softmax(logits, dim=1).data.squeeze()

# 排序获取预测结果
prob_sorted, idx_sorted = probabilities.sort(0, True)
prob_sorted = prob_sorted.cpu().numpy()
idx_sorted = idx_sorted.cpu().numpy()

# 输出预测结果
for i in range(NUM_CLASSES):
    print('{:.3f} -> {}'.format(prob_sorted[i], class_labels[idx_sorted[i]]))

(三)生成类激活热图

def generate_cam(feature_maps, fc_weights, class_indices):
    batch_size, channels, height, width = feature_maps.shape
    cam_results = []
    
    for idx in class_indices:
        # 将特征图展平为二维矩阵
        flattened_features = feature_maps.reshape((channels, height * width))
        
        # 计算加权激活值
        cam = fc_weights[idx].dot(flattened_features)
        cam = cam.reshape(height, width)
        
        # 归一化处理
        cam_normalized = (cam - cam.min()) / (cam.max() - cam.min())
        cam_uint8 = np.uint8(255 * cam_normalized)
        
        cam_results.append(cam_uint8)
    
    return cam_results

# 生成最高概率类别的热图
attention_maps = generate_cam(conv_features, fc_weights, idx_sorted)
print('{} 预测结果: {}'.format(image_name, class_labels[idx_sorted[0]]))

(四)热图可视化与图像保存

# 读取原始图像
original_img = cv2.imread(test_image_path)
img_height, img_width, _ = original_img.shape

# 调整热图尺寸并应用色彩映射
resized_heatmap = cv2.resize(attention_maps[0], (img_width, img_height))
colored_heatmap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)

# 叠加原始图像与热图
superimposed_result = colored_heatmap * 0.3 + original_img * 0.5

# 添加预测标签文字
prediction_text = '%s %.2f%%' % (class_labels[idx_sorted[0]], prob_sorted[0] * 100)
cv2.putText(superimposed_result, prediction_text, (210, 40), 
            fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.9,
            color=(123, 222, 238), thickness=2, lineType=cv2.LINE_AA)

# 保存结果图像
output_directory = r'/data/heatmap/'
if not os.path.exists(output_directory):
    os.makedirs(output_directory)

base_name = image_name.split(".")[-2]
cv2.imwrite(os.path.join(output_directory, base_name + '_heatmap.jpg'), superimposed_result)

四、实验结果分析

通过上述实现流程,成功生成了输入图像的注意力热图。实验结果表明,注意力热图能够准确标注出神经网络在进行类别预测时重点关注的图像区域。

以车辆识别任务为例,热力图的高亮区域主要集中在车辆轮廓及关键部件位置,验证了模型能够有效捕获目标的判别性特征。这一可视化结果为模型可解释性研究提供了有力支持。

相关文章

Linux crontab 详解

1) crontab 是什么cron 是 Linux 的定时任务守护进程;crontab 是用来编辑/查看“按时间周期执行命令”的表(cron table)。常见两类:用户 crontab:每个用户一份(crontab -e 编辑)系统级 crontab / cron.d:可指定执行用户(/etc/crontab、/etc/cron.d/*)2) crontab 时间...

富文本里可以允许的 HTML 属性

一、所有标签默认允许的安全属性(极少)class        (可选)id           (通常建议禁用)title️ 注意:id 容易被滥用做锚点注入,很多系统直接禁用class 允许的话最好只允许固定前缀(如 editor-*)二、a 标签允许属性<a href="" t...

Mac 安装 Node.js 指南

方法一:通过官网安装包(最简单,适合初学者)如果你只是想快速安装并开始使用,这是最直接的方法。访问 Node.js 官网。页面会显示两个版本:LTS (Recommended For Most Users):长期支持版,最稳定。建议选这个。Current:最新特性版,包含最新功能但可能不够稳定。下载 .pkg 安装包并运行。按照安装向导点击“下一步”即可完成。方法二:使用 Homebrew 安装(...

Dom\HTML_NO_DEFAULT_NS 的副作用:自动加闭合标签

在使用Dom\HTMLDocument时,Dom\HTML_NO_DEFAULT_NS 将禁止在解析过程中设置元素的命名空间, 此设置是为了与DOMDocument向后兼容而存在的。当使用它时,已知的一个副作用就是:自动加闭合标签例如 </img> 为什么会这样?当你使用:Dom\HTML_NO_DEFAULT_NS文档会变成 无命名空间模式,此时内部更接近 XML...

Laravel 事件和监听器创建

在 Laravel 中,使用 Artisan 命令创建 Events(事件) 和 Listeners(监听器) 是非常高效的。你可以通过以下几种方式来实现:1. 手动创建单个 Event如果你只想创建一个事件类,可以使用 make:event 命令:Bashphp artisan make:event UserRegistered执行后,文件将生成在 app/Even...

自定义域名解析神器 dnsmasq

什么是 dnsmasq?dnsmasq 是一个轻量级、功能强大的网络服务工具,专为小型和中等规模网络设计。它是一个综合的网络基础设施解决方案[1]。dnsmasq 能做什么?功能说明应用场景DNS 转发与缓存将 DNS 查询转发到上游服务器(ISP、Google DNS 等),并在本地缓存结果加快 DNS 查询速度,减少外部 DNS 流量本地 DNS解析本地网络设备的主机名,无需编辑&n...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。