当前位置:首页 > 工具 > 正文内容

CSV数据读写与多实验结果可视化分析

访客 工具 2026年6月3日 1

在机器学习实验中,记录训练过程中的关键指标(如损失值、准确率等)是模型调试和性能评估的重要环节。Python 的 csv 模块提供了便捷的接口用于结构化数据的持久化存储。以下展示如何将迭代训练的成本数据写入 CSV 文件:

import csv
import os

# 假设每轮训练后收集成本数据
cost_log = []
for epoch in range(total_epochs):
    loss = compute_loss()  # 示例函数
    cost_log.append({'Iteration': epoch, 'TrainingLoss': loss})

# 写入CSV文件
log_directory = params['log_dir']
os.makedirs(log_directory, exist_ok=True)
file_path = os.path.join(log_directory, 'training_loss.csv')

field_names = ['Iteration', 'TrainingLoss']
with open(file_path, 'w', newline='', encoding='utf-8') as file:
    writer = csv.DictWriter(file, fieldnames=field_names)
    writer.writeheader()
    writer.writerows(cost_log)

完成数据记录后,可通过 pandasseaborn 对多个实验的结果进行聚合分析与可视化。例如,在强化学习任务中比较不同超参数配置下的收敛行为。

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import json
import os
from glob import glob

def load_experiment_data(root_path, label=None):
    """
    遍历实验目录,加载每个种子运行下的CSV日志
    """
    datasets = []
    run_id = 0
    for subdir in glob(os.path.join(root_path, "*/")):
        log_file = os.path.join(subdir, "training_loss.csv")
        param_file = os.path.join(subdir, "params.json")
        
        if os.path.exists(log_file) and os.path.exists(param_file):
            df = pd.read_csv(log_file)
            with open(param_file, 'r') as f:
                config = json.load(f)
            
            # 添加元信息列
            df['Run'] = run_id
            df['Experiment'] = label or config.get('exp_name', 'unknown')
            datasets.append(df)
            run_id += 1
            
    return pd.concat(datasets, ignore_index=True) if datasets else pd.DataFrame()

def generate_comparison_plot(experiment_dirs, labels=None, metric='TrainingLoss', output_dir='figures'):
    """
    绘制多组实验的时序对比图
    """
    os.makedirs(output_dir, exist_ok=True)
    all_data = []

    for path, name in zip(experiment_dirs, labels or experiment_dirs):
        data = load_experiment_data(path, label=name)
        if not data.empty:
            all_data.append(data)

    if not all_data:
        print("未找到有效数据")
        return

    combined = pd.concat(all_data, ignore_index=True)
    
    plt.figure(figsize=(14, 7))
    sns.set_style("darkgrid", {"axes.facecolor": ".9"})
    sns.lineplot(data=combined, x='Iteration', y=metric, hue='Experiment', ci='sd')
    
    plt.title(f'{metric} 跨实验对比', fontsize=16)
    plt.xlabel('迭代次数')
    plt.ylabel(metric)
    plt.legend(title='实验组')
    
    save_path = os.path.join(output_dir, f"{metric}_comparison.png")
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    plt.close()

上述流程支持从多个独立训练运行中提取数据,并基于指定指标生成带置信区间的趋势图。通过封装命令行接口,用户可在终端直接调用绘图工具:

python plotter.py --dirs ./data/exp_baseline ./data/exp_tuned \
                        --labels "Default" "Optimized" \
                        --metric TrainingLoss

该方法适用于大规模超参搜索后的结果分析,能够高效呈现不同配置对模型训练动态的影响。

相关文章

Trojan服务器搭建与配置

一、整体架构(先对齐认知)Clash Meta (PC / iOS / Android)        ↓ TLS   Trojan Server (443)        ↓     InternetTrojan 的核心是: TLS + HTTPS 流量伪装 看起来像正常网站 非常适合...

Tailscale 的详细用法

Tailscale 是一种基于 WireGuard 协议 的 零配置 VPN(虚拟私有网络)服务,让设备之间能够 安全、加密地直接连接,就像它们在同一个本地网络一样。它的核心特点是 简单、安全、跨平台。Tailscale 非常适合 没有公网 IP、两台电脑不在同一局域网 的场景。 简单来说,Tailscale 是什么?Tailscale 是一款让你的各种设备(电脑、服务器、手机...

Clash Tun 模式 导致 爱快(iKuai SD-Wan)内网域名无法访问

一、Clash  DNS 配置dns:  enable: true  listen: 0.0.0.0:53  ipv6: true  enhanced-mode: redir-host  nameserver:    - 223.5.5.5    - 223.6.6.6iKuai 内网域名 ...

深入解析Node.js运行环境与异步I/O架构

深入解析Node.js运行环境与异步I/O架构

核心定义与价值Node.js本质上是一个JavaScript运行环境,而非编程语言或应用框架。它赋予了JavaScript脱离浏览器在服务端、命令行工具及网络应用中执行的能力。其核心意义在于:用单一语言打通前后端开发壁垒。基于事件驱动与非阻塞I/O的架构特性,Node.js在处理API网关、实时通信及微服务等I/O密集型场景时表现卓越,已成为现代后端工程的主流选择。浏览器沙箱限制1995年Java...

ADO.NET SQL参数化查询的最佳实践

在 ADO.NET 中执行 SQL 查询时,参数化查询是一种关键的安全措施和性能优化手段。它通过将 SQL 命令和用户提供的数据分开处理,有效防止了 SQL 注入攻击,并有助于数据库缓存执行计划。下面总结了几种常用的参数化查询方式。 1. 使用 SqlParameter 对象(推荐) 这是最推荐的参数化查询方式。通过显式创建 SqlParameter 对象,您可以精确控制参数的类...

基于ELK的日志集中化分析系统搭建

构建统一日志管理平台的必要性 在分布式架构中,各服务节点独立运行,日志分散存储于不同主机。传统通过命令行工具如grep、awk逐个检索日志的方式,在数据量庞大时效率极低,难以实现快速定位问题。为提升运维效率,需建立集中式日志处理体系,具备日志采集、传输、存储、分析与告警能力。 ELK技术栈核心组件解析 Elasticsearch:分布式搜索引擎,支持全文检索、实时数据分析和高可用集群部署,...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。