CSV数据读写与多实验结果可视化分析
在机器学习实验中,记录训练过程中的关键指标(如损失值、准确率等)是模型调试和性能评估的重要环节。Python 的 csv 模块提供了便捷的接口用于结构化数据的持久化存储。以下展示如何将迭代训练的成本数据写入 CSV 文件:
import csv
import os
# 假设每轮训练后收集成本数据
cost_log = []
for epoch in range(total_epochs):
loss = compute_loss() # 示例函数
cost_log.append({'Iteration': epoch, 'TrainingLoss': loss})
# 写入CSV文件
log_directory = params['log_dir']
os.makedirs(log_directory, exist_ok=True)
file_path = os.path.join(log_directory, 'training_loss.csv')
field_names = ['Iteration', 'TrainingLoss']
with open(file_path, 'w', newline='', encoding='utf-8') as file:
writer = csv.DictWriter(file, fieldnames=field_names)
writer.writeheader()
writer.writerows(cost_log)
完成数据记录后,可通过 pandas 和 seaborn 对多个实验的结果进行聚合分析与可视化。例如,在强化学习任务中比较不同超参数配置下的收敛行为。
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import json
import os
from glob import glob
def load_experiment_data(root_path, label=None):
"""
遍历实验目录,加载每个种子运行下的CSV日志
"""
datasets = []
run_id = 0
for subdir in glob(os.path.join(root_path, "*/")):
log_file = os.path.join(subdir, "training_loss.csv")
param_file = os.path.join(subdir, "params.json")
if os.path.exists(log_file) and os.path.exists(param_file):
df = pd.read_csv(log_file)
with open(param_file, 'r') as f:
config = json.load(f)
# 添加元信息列
df['Run'] = run_id
df['Experiment'] = label or config.get('exp_name', 'unknown')
datasets.append(df)
run_id += 1
return pd.concat(datasets, ignore_index=True) if datasets else pd.DataFrame()
def generate_comparison_plot(experiment_dirs, labels=None, metric='TrainingLoss', output_dir='figures'):
"""
绘制多组实验的时序对比图
"""
os.makedirs(output_dir, exist_ok=True)
all_data = []
for path, name in zip(experiment_dirs, labels or experiment_dirs):
data = load_experiment_data(path, label=name)
if not data.empty:
all_data.append(data)
if not all_data:
print("未找到有效数据")
return
combined = pd.concat(all_data, ignore_index=True)
plt.figure(figsize=(14, 7))
sns.set_style("darkgrid", {"axes.facecolor": ".9"})
sns.lineplot(data=combined, x='Iteration', y=metric, hue='Experiment', ci='sd')
plt.title(f'{metric} 跨实验对比', fontsize=16)
plt.xlabel('迭代次数')
plt.ylabel(metric)
plt.legend(title='实验组')
save_path = os.path.join(output_dir, f"{metric}_comparison.png")
plt.savefig(save_path, dpi=200, bbox_inches='tight')
plt.close()
上述流程支持从多个独立训练运行中提取数据,并基于指定指标生成带置信区间的趋势图。通过封装命令行接口,用户可在终端直接调用绘图工具:
python plotter.py --dirs ./data/exp_baseline ./data/exp_tuned \
--labels "Default" "Optimized" \
--metric TrainingLoss
该方法适用于大规模超参搜索后的结果分析,能够高效呈现不同配置对模型训练动态的影响。
