A Practical Guide to Reproducing Paper Code
I. Environment Setup
1. Create an isolated development environment
conda create --name drgcnntest python=3.8.18
2. Activate the virtual environment
conda activate drgcnntest
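Note:
In PyCharm, if the terminal prompt shows PS (PowerShell) instead of the base conda environment, change the shell path to cmd.exe. Menu path as given in the original notes: Tools → Deployment → Configuration; set the remote connection parameters, then restart the terminal. (For a local terminal, the shell path is normally set under Settings → Tools → Terminal.)
3. Install the deep learning framework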
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 cpuonly -c pytorch
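When installing the remaining dependencies, managing them with a requirements.txt file is recommended so that versions stay pinned and reproducible.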
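To make the requirements.txt advice concrete, here is a minimal sketch that parses pinned entries (name==version) and reports which pins the active environment does not satisfy. The function names parse_pins and find_mismatches, and the package versions in the example, are assumptions for illustration; they are not taken from the paper's repository.

```python
from importlib import metadata


def parse_pins(lines):
    """Return {package: version} for lines of the form 'name==version'."""
    pins = {}
    for raw in lines:
        line = raw.split('#', 1)[0].strip()  # drop comments and whitespace
        if not line or '==' not in line:
            continue  # skip blank lines and unpinned entries
        name, version = line.split('==', 1)
        pins[name.strip().lower()] = version.strip()
    return pins


def find_mismatches(pins, get_version=metadata.version):
    """Return {package: (wanted, installed_or_None)} for every unsatisfied pin."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


# Hypothetical requirements.txt content (illustrative versions only)
example = ["numpy==1.24.4", "pillow==10.0.1  # image I/O", "", "tqdm"]
pins = parse_pins(example)
```

In practice, one would read the lines from the real requirements.txt and call find_mismatches(pins) inside the activated drgcnntest environment; an empty result means every pinned package matches.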
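II. Code Walkthrough
1. Image preprocessing module
The eye_pre_process module contains the core processing pipeline:
- Command-line argument parsing
- Multi-threaded image processing
- Adaptive cropping and saving
Example argument configuration: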
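import argparse

# Command-line options for the preprocessing script
parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', type=str, default=r'./data/raw')         # folder with the raw images
parser.add_argument('--output_dir', type=str, default=r'./data/processed')  # folder for the cropped images
parser.add_argument('--image_size', type=int, default=512)                  # output edge length in pixels
parser.add_argument('--num_workers', type=int, default=8)                   # number of worker threads
args = parser.parse_args()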
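The multi-threaded step listed above can be sketched with the standard library's ThreadPoolExecutor. This is a minimal sketch, not the module's actual implementation: run_parallel, IMAGE_SUFFIXES, and the worker callback signature worker(src_path, dst_path) are hypothetical names introduced for illustration.

```python
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

# Assumed set of image extensions; adjust to the dataset at hand
IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png', '.tif', '.tiff'}


def run_parallel(input_dir, output_dir, worker, num_workers=8):
    """Call worker(src_path, dst_path) for every image in input_dir using a thread pool.

    Returns a list of (src_path, exception) pairs for files whose worker raised.
    """
    input_dir, output_dir = Path(input_dir), Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    sources = sorted(p for p in input_dir.iterdir()
                     if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES)
    failures = []
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # Map each future back to its source file so errors can be reported per image
        futures = {pool.submit(worker, src, output_dir / src.name): src
                   for src in sources}
        for future, src in futures.items():
            try:
                future.result()  # re-raises any exception from the worker
            except Exception as exc:
                failures.append((src, exc))
    return failures
```

With the arguments parsed above, a call might look like run_parallel(args.input_dir, args.output_dir, lambda s, d: process_image(s, d, (args.image_size, args.image_size)), args.num_workers). Collecting failures instead of aborting lets one corrupt image be reported without stopping the whole batch.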
Core processing logic:
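import numpy as np
from PIL import Image

def process_image(img_path, output_path, target_size):
    # target_size is a (width, height) tuple, as expected by Image.resize
    with Image.open(img_path) as img:
        arr = np.array(img)
        h, w = arr.shape[:2]
        bbox = None
        # Adaptive cropping: only attempted for images clearly wider than tall
        if w > 1.2 * h:
            strip = max(w // 32, 1)
            # Per-channel maximum of the left and right border strips, treated as background
            left_max = arr[:, :strip].max(axis=(0, 1)).astype(int)
            right_max = arr[:, -strip:].max(axis=(0, 1)).astype(int)
            bg_threshold = np.maximum(left_max, right_max) + 10
            foreground_mask = (arr > bg_threshold).astype(np.uint8)
            bbox = Image.fromarray(foreground_mask).getbbox()
        # Fall back to a square crop if no foreground was found or the box is too narrow
        if not bbox or (bbox[2] - bbox[0] < 0.8 * h):
            bbox = calculate_square_bbox(img)  # helper defined elsewhere in the module
        cropped = img.crop(bbox).resize(target_size)
        cropped.save(output_path, quality=100)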
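The fallback helper calculate_square_bbox is called above but not shown in the source. One plausible version, assuming it simply returns the largest centered square, could look like the sketch below; the real helper in the repository may differ (for example, it might anchor the square differently).

```python
def calculate_square_bbox(img):
    """Return a centered square box (left, upper, right, lower).

    Assumption: img exposes .size as (width, height), as PIL images do.
    """
    width, height = img.size
    side = min(width, height)        # edge length of the largest square that fits
    left = (width - side) // 2       # center horizontally
    upper = (height - side) // 2     # center vertically
    return (left, upper, left + side, upper + side)
```

The returned tuple follows the box convention used by Image.crop, so it can be passed to img.crop directly.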
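2. Model training framework
The Encoder module contains the complete training pipeline:
- Configuration loading and path management
- Model parameter counting
- Dataset generation
- Training, validation, and testing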
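As an illustration of the configuration loading and path management step, the sketch below loads a JSON file into a dataclass and creates the checkpoint directory. The class name TrainConfig, the JSON format, and the default values are assumptions; only the field names lr, epochs, and save_path come from the training code in this guide.

```python
import json
from dataclasses import dataclass, fields
from pathlib import Path


@dataclass
class TrainConfig:
    lr: float = 1e-3                          # assumed default learning rate
    epochs: int = 10                          # assumed default epoch count
    save_path: str = './checkpoints/best.pt'  # assumed checkpoint location

    @classmethod
    def from_json(cls, path):
        """Load a config file, rejecting keys that are not declared fields."""
        raw = json.loads(Path(path).read_text(encoding='utf-8'))
        unknown = set(raw) - {f.name for f in fields(cls)}
        if unknown:
            raise ValueError(f'unknown config keys: {sorted(unknown)}')
        return cls(**raw)

    def prepare_paths(self):
        """Create the directory that will hold the checkpoint file."""
        Path(self.save_path).parent.mkdir(parents=True, exist_ok=True)
        return self
```

Rejecting unknown keys catches typos such as "epoch" early instead of silently training with a default. For the parameter-counting step, a common PyTorch idiom is sum(p.numel() for p in model.parameters() if p.requires_grad).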
Core training loop:
import torch
import torch.nn as nn

def train(cfg, model, train_loader, val_loader):
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    criterion = nn.CrossEntropyLoss()
    best_val_loss = float('inf')  # lowest validation loss seen so far
    for epoch in range(cfg.epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        # Validation phase (no gradients needed)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        # Save the model whenever the validation loss improves
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model.state_dict(), cfg.save_path)  # helper defined elsewhere in the project
The evaluation module computes several metrics in one pass:
import torch

# Accuracy, Precision, Recall and F1Score are metric classes imported elsewhere
# in the project; each is expected to provide update() and compute().
class PerformanceEvaluator:
    def __init__(self, criterion, num_classes):
        self.criterion = criterion
        self.num_classes = num_classes
        self.metrics = {
            'accuracy': Accuracy(),
            'precision': Precision(num_classes),
            'recall': Recall(num_classes),
            'f1_score': F1Score(num_classes)
        }

    def evaluate(self, model, dataloader):
        model.eval()
        total_loss = 0.0
        # Disable gradient tracking during evaluation
        with torch.no_grad():
            for images, labels in dataloader:
                outputs = model(images)
                loss = self.criterion(outputs, labels)
                total_loss += loss.item()
                # Accumulate per-batch statistics in every metric
                for metric in self.metrics.values():
                    metric.update(outputs, labels)
        return {
            'loss': total_loss / len(dataloader),  # mean loss per batch
            'metrics': {k: m.compute() for k, m in self.metrics.items()}
        }
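To show what the update/compute pattern accumulates, here is a framework-free sketch that builds a confusion matrix from predicted and true class indices and derives accuracy plus macro-averaged precision, recall, and F1. The function names are hypothetical, and macro averaging is an assumption: the source does not say how its Precision, Recall, and F1Score classes average across classes.

```python
def confusion_matrix(preds, labels, num_classes):
    """matrix[t][p] counts samples with true class t that were predicted as class p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, labels):
        matrix[t][p] += 1
    return matrix


def macro_metrics(matrix):
    """Accuracy and macro-averaged precision/recall/F1 from a confusion matrix."""
    n = len(matrix)
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(n))
    precisions, recalls, f1s = [], [], []
    for c in range(n):
        tp = matrix[c][c]
        predicted = sum(matrix[r][c] for r in range(n))  # column sum: predicted as c
        actual = sum(matrix[c])                          # row sum: truly c
        p = tp / predicted if predicted else 0.0
        r = tp / actual if actual else 0.0
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f)
    return {
        'accuracy': correct / total if total else 0.0,
        'precision': sum(precisions) / n,
        'recall': sum(recalls) / n,
        'f1_score': sum(f1s) / n,
    }


# Predicted class indices (e.g. the argmax of the model outputs) and true labels
result = macro_metrics(confusion_matrix([0, 1, 1, 2], [0, 1, 2, 2], num_classes=3))
```

In this toy example three of four predictions are correct, so the accuracy is 0.75, while the single confusion between classes 1 and 2 lowers the precision of class 1 and the recall of class 2.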