Matplotlib 数据可视化实战指南
Matplotlib 是 Python 生态中最核心的可视化工具，与 NumPy 深度整合，为机器学习中的数据探索提供直观支持。本文聚焦于高频使用场景，帮助开发者快速建立可视化能力。
一、快速上手
1. 基础曲线绘制
以神经网络中常见的 Tanh 激活函数为例，演示最基本的绘图流程：
import matplotlib.pyplot as plt
import numpy as np

# Sample 800 evenly spaced inputs on [-8, 8].
t = np.linspace(-8, 8, 800)
# tanh(t) = (e^t - e^-t) / (e^t + e^-t). np.tanh evaluates the same function
# without forming e^t explicitly, so it stays finite for large |t|, where the
# hand-written ratio overflows to inf/inf = nan (|t| > ~710).
h = np.tanh(t)
plt.plot(t, h)
plt.show()
plt.plot() 在当前坐标系中绘制曲线，plt.show() 渲染并显示图形窗口。
2. 图表元素配置
# The Chinese labels below need a CJK-capable font. Matplotlib's default
# (DejaVu Sans) has no CJK glyphs and draws them as empty boxes. Matplotlib
# uses the first installed family in this list (3.6+ also falls back per glyph),
# so this works on Windows, macOS and Linux as long as one of them is present.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'PingFang SC',
                                   'Noto Sans CJK SC', 'DejaVu Sans']
# CJK fonts such as SimHei often lack the Unicode minus sign (U+2212) used for
# negative tick labels; render them with an ASCII '-' instead.
plt.rcParams['axes.unicode_minus'] = False

t = np.linspace(-8, 8, 800)
h = np.tanh(t)  # same values as (e^t - e^-t) / (e^t + e^-t), overflow-safe

# Start on a fresh figure so nothing left current by earlier code leaks in.
plt.figure()
# Axis limits
plt.xlim(-6, 6)
plt.ylim(-1.2, 1.2)
# Axis labels and title
plt.xlabel("输入值")
plt.ylabel("输出值")
plt.title("Tanh 激活函数")
# Dotted grid plus dashed reference lines through the origin
plt.grid(ls=":", c="gray")
plt.axhline(y=0, c="orange", ls="--", lw=1.5)
plt.axvline(x=0, c="orange", ls="--", lw=1.5)
plt.plot(t, h)
# Save before show(): once the window is closed the figure is released and a
# later savefig would write a blank figure.
plt.savefig("tanh_demo.png", dpi=300, bbox_inches="tight")
plt.show()
3. 多曲线对比
t = np.linspace(-5, 5, 500)
# Three common activation functions
y1 = np.tanh(t)
y2 = np.maximum(0, t)        # ReLU
y3 = 1 / (1 + np.exp(-t))    # Sigmoid

# Draw on a new figure. Otherwise plt.plot reuses whatever figure is still
# current (plt.savefig, for example, does not close it).
plt.figure()
plt.xlim(-4, 4)
# Note: ReLU equals t, so above t = 2 it is clipped by this upper limit.
plt.ylim(-1.5, 2)
plt.plot(t, y1, c="#2E86AB", ls="-", lw=2, label="Tanh")
plt.plot(t, y2, c="#A23B72", ls="-.", lw=2, label="ReLU")
plt.plot(t, y3, c="#F18F01", ls="--", lw=2, label="Sigmoid")
plt.legend(loc="upper left", frameon=True, shadow=True)
plt.show()
4. 子图布局
使用 plt.subplots 创建矩阵式（网格）子图布局：
t = np.linspace(-3*np.pi, 3*np.pi, 600)
wave_a = np.sin(t**2)                    # chirp: instantaneous frequency 2|t| grows with |t|
wave_b = np.cos(t) * np.exp(-t**2/10)    # cosine under a Gaussian envelope
wave_c = np.tanh(np.sin(t*3))            # sine squashed through tanh
wave_d = np.abs(np.sinc(t/np.pi))        # |sin(t)/t|, since np.sinc(x) = sin(pi*x)/(pi*x)

fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# sin(t**2) has constant amplitude and rising frequency, i.e. a chirp,
# not a damped oscillation, so the panel title says so.
axes[0,0].plot(t, wave_a, 'c-')
axes[0,0].set_title('啁啾信号 A')

axes[0,1].fill_between(t, wave_b, alpha=0.6, color='coral')
axes[0,1].set_title('衰减波 B')

axes[1,0].plot(t, wave_c, 'g:', lw=2)
axes[1,0].set_title('调制信号 C')

# Every 10th sample only, to show discrete points
axes[1,1].scatter(t[::10], wave_d[::10], c='purple', s=20, alpha=0.7)
axes[1,1].set_title('采样点 D')

fig.subplots_adjust(wspace=0.3, hspace=0.4)
fig.suptitle("多信号分析面板", fontsize=14, y=0.98)
plt.show()
5. 常用图表类型
# Matplotlib's default font has no CJK glyphs, so the Chinese labels below
# would render as empty boxes. Matplotlib uses the first installed family in
# this list (3.6+ also falls back per glyph).
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'PingFang SC',
                                   'Noto Sans CJK SC', 'DejaVu Sans']
# CJK fonts such as SimHei often lack the Unicode minus sign (U+2212) used for
# negative tick labels (e.g. the scatter panel below); use ASCII '-' instead.
plt.rcParams['axes.unicode_minus'] = False
# Seed once, before the first random draw, so every randomly generated
# panel (bars, histogram, stems, bubbles) is reproducible.
np.random.seed(42)

fig, axes = plt.subplots(3, 2, figsize=(10, 12))

# Vertical bar chart
categories = ['A类', 'B类', 'C类', 'D类', 'E类']
heights = np.random.randint(10, 100, 5)
axes[0,0].bar(categories, heights, color='steelblue', edgecolor='navy')
axes[0,0].set_title('分类统计')

# Histogram of a right-skewed gamma(shape=2, scale=2) sample
samples = np.random.gamma(2, 2, 1000)
axes[0,1].hist(samples, bins=40, color='seagreen', edgecolor='white')
axes[0,1].set_title('分布直方')

# Pie chart; the non-zero explode entries pull those wedges outward
segments = ['产品X', '产品Y', '产品Z', '其他']
ratios = [35, 25, 30, 10]
explode = (0.05, 0, 0.05, 0)
axes[1,0].pie(ratios, explode=explode, labels=segments, autopct='%1.1f%%',
              colors=['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4'])
axes[1,0].set_title('市场份额')

# Lollipop chart via stem(); basefmt=' ' hides the baseline
idx = np.arange(12)
vals = np.random.randint(20, 80, 12)
axes[1,1].stem(idx, vals, linefmt='-.', markerfmt='D', basefmt=' ')
axes[1,1].set_title('离散趋势')

# Bubble scatter: marker area and colour encode two extra dimensions
x_pos = np.random.randn(60)
y_pos = np.random.randn(60)
sizes = np.abs(x_pos * y_pos) * 300
colors = np.random.rand(60)
axes[2,0].scatter(x_pos, y_pos, s=sizes, c=colors, cmap='viridis', alpha=0.7)
axes[2,0].set_title('多维散点')

# Polar plot: an Axes' projection is fixed at creation, so remove the last
# grid cell and add a polar Axes in the same slot (3x2 grid, index 6).
fig.delaxes(axes[2,1])
polar_ax = fig.add_subplot(3, 2, 6, projection='polar')
theta = np.linspace(0, 2*np.pi, 100)
r = 1 + 0.5 * np.cos(4*theta)   # four-lobed curve
polar_ax.plot(theta, r, lw=2)
polar_ax.fill(theta, r, alpha=0.3)
polar_ax.set_title('极坐标玫瑰线')

plt.tight_layout()
plt.show()
6. 参数缩写速查
x = np.linspace(-6, 6, 100)
y = np.sinc(x)   # normalized sinc: sin(pi*x) / (pi*x)

# Use a fresh figure. Otherwise plt.plot draws onto whichever Axes is still
# current, e.g. the last subplot added by earlier code.
plt.figure()
# Abbreviated keyword arguments: c=color, ls=linestyle, lw=linewidth,
# ms=markersize, mfc/mec/mew = marker face colour / edge colour / edge width.
plt.plot(x, y, c='darkred', ls='-', lw=3, marker='s', ms=8,
         mfc='yellow', mec='black', mew=1.5, label='Sinc函数')
plt.legend()
plt.show()
| 缩写 | 全称 | 说明 |
|---|---|---|
| `c` | `color` | 线条颜色 |
| `ls` | `linestyle` | 线型：`-`、`--`、`-.`、`:` |
| `lw` | `linewidth` | 线宽数值 |
| `marker` | `marker` | 标记样式：`o`、`s`、`^`、`D`、`*` |
| `ms` | `markersize` | 标记大小 |
| `mfc` | `markerfacecolor` | 标记填充色 |
| `mec` | `markeredgecolor` | 标记边框色 |
| `mew` | `markeredgewidth` | 标记边框宽度 |
二、进阶技巧
1. 注释与标注
fig, ax = plt.subplots(figsize=(9, 6))
t = np.linspace(0, 4*np.pi, 500)
signal = np.sin(t) * np.exp(-t/10)   # damped sine wave
ax.plot(t, signal, lw=2, color='navy')

# Mark the peak. The envelope e^(-t/10) only decays, so the global maximum
# is the first crest. Searching the whole signal avoids a hard-coded prefix
# length that silently misses the crest if the sample count changes.
peak_idx = np.argmax(signal)
peak_x, peak_y = t[peak_idx], signal[peak_idx]
ax.annotate(f'峰值: ({peak_x:.2f}, {peak_y:.2f})',
            xy=(peak_x, peak_y), xycoords='data',
            xytext=(peak_x+2, peak_y+0.3), textcoords='data',
            arrowprops=dict(arrowstyle='->', color='red', lw=1.5),
            fontsize=11, color='darkred')

# Text box with the curve's formula, placed in data coordinates
ax.text(8, 0.5, '衰减正弦波\ny = sin(t)·e^(-t/10)',
        fontsize=12, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
ax.set_xlim(0, 14)
ax.set_ylim(-1, 1.2)
plt.show()
2. 三维曲面
# Only needed on matplotlib < 3.2, where importing it registers the '3d'
# projection; newer versions register it automatically.
from mpl_toolkits.mplot3d import Axes3D

# Square 80 x 80 grid covering [-4, 4] on both axes
grid_axis = np.linspace(-4, 4, 80)
U, V = np.meshgrid(grid_axis, grid_axis)
# Radial ripple: height is the sine of the distance from the origin
W = np.sin(np.sqrt(U**2 + V**2))

fig = plt.figure(figsize=(10, 8))
ax3d = fig.add_subplot(111, projection='3d')
surf = ax3d.plot_surface(
    U, V, W,
    cmap='plasma',
    linewidth=0,
    antialiased=True,
)
ax3d.set_zlim(-1.2, 1.2)
fig.colorbar(surf, shrink=0.5, aspect=10, label='幅值')
3. 地理数据叠加
import matplotlib.image as mpimg  # for loading a real basemap image (see placeholder note below)

# Simulated city housing data
np.random.seed(123)
n_points = 300
lons = np.random.uniform(-122.5, -121.5, n_points)
lats = np.random.uniform(37.5, 38.5, n_points)
prices = np.random.lognormal(12.5, 0.4, n_points)       # per-point median house price (USD)
populations = np.random.lognormal(8, 0.8, n_points)     # per-point population

fig, ax = plt.subplots(figsize=(10, 10))
# Background-map placeholder. With a real basemap, draw it first, e.g.
# ax.imshow(mpimg.imread(path), extent=(lon_min, lon_max, lat_min, lat_max)).
ax.set_facecolor('lightgray')

# Marker area encodes population; colour encodes price
scatter = ax.scatter(lons, lats,
                     s=populations/50,
                     c=prices, cmap='YlOrRd',
                     alpha=0.6, edgecolors='black', linewidth=0.5)
ax.set_xlim(-122.6, -121.4)
ax.set_ylim(37.4, 38.6)
# A degree of longitude covers only cos(latitude) times the ground distance of
# a degree of latitude. Scale y accordingly so the map is not stretched
# east-west (about 27% at 38°N).
mid_lat = np.mean(ax.get_ylim())
ax.set_aspect(1 / np.cos(np.deg2rad(mid_lat)))
ax.set_xlabel('经度')
ax.set_ylabel('纬度')

cbar = fig.colorbar(scatter, ax=ax)
cbar.set_label('房价中位数 (USD)', rotation=270, labelpad=20)
plt.show()
4. 等高线填充
def potential_field(x, y):
    """Return the sample 2-D potential sin(x)cos(y) + 0.3*sin(3x)cos(3y).

    Evaluated element-wise with NumPy, so x and y may be scalars or
    broadcast-compatible arrays (e.g. the output of np.meshgrid).
    """
    return np.sin(x) * np.cos(y) + 0.3 * np.sin(3*x) * np.cos(3*y)


x_grid = np.linspace(-np.pi, np.pi, 200)
y_grid = np.linspace(-np.pi, np.pi, 200)
X_grid, Y_grid = np.meshgrid(x_grid, y_grid)
Z_values = potential_field(X_grid, Y_grid)

plt.figure(figsize=(9, 7))
contours = plt.contourf(X_grid, Y_grid, Z_values, levels=25, cmap='RdBu_r')
plt.colorbar(contours, label='势能值')
# Reuse the filled plot's levels so the black iso-lines are guaranteed to sit
# exactly on the colour-band boundaries.
plt.contour(X_grid, Y_grid, Z_values, levels=contours.levels,
            colors='black', linewidths=0.3)
plt.title('二维势场分布')
plt.show()
5. 动态可视化
from matplotlib import animation

fig, ax = plt.subplots(figsize=(8, 4))
ax.set_xlim(0, 4*np.pi)
ax.set_ylim(-2, 2)
line_obj, = ax.plot([], [], lw=2, color='forestgreen')
# The x samples are the same for every frame, so compute them once.
x_data = np.linspace(0, 4*np.pi, 400)
# Current fill artist; each frame removes it and draws a replacement.
fill_region = ax.fill_between([], [], alpha=0.3)


def frame_init():
    """Clear the line before the first frame is drawn."""
    line_obj.set_data([], [])
    return line_obj,


def frame_update(frame):
    """Draw frame *frame*: a damped sine wave whose phase advances by frame/10 rad."""
    global fill_region
    y_data = 1.5 * np.sin(x_data + frame/10) * np.exp(-x_data/15)
    line_obj.set_data(x_data, y_data)
    # Replace the previous fill by removing that artist directly.
    # ax.collections.clear() no longer works: the Axes child lists are
    # read-only ArtistLists in recent matplotlib (mutation was deprecated in 3.5).
    fill_region.remove()
    fill_region = ax.fill_between(x_data, 0, y_data, alpha=0.3, color='forestgreen')
    return line_obj, fill_region


anim = animation.FuncAnimation(fig, frame_update, init_func=frame_init,
                               frames=200, interval=50, blit=False)
# Save as a GIF (requires pillow). fps=20 matches interval=50 ms, so the file
# plays at the same speed as the on-screen animation.
# anim.save('wave_propagation.gif', writer='pillow', fps=20)
plt.show()
Matplotlib 的学习曲线平缓，但功能纵深极大。建议从简单曲线开始，逐步掌握子图布局、样式定制和交互元素，最终形成符合个人风格的可视化流程。