当前位置: 首页 > news >正文

别再死记ResNet结构了!用PyTorch手写一个ResNet-18,彻底搞懂残差连接和Bottleneck

从零实现ResNet-18用PyTorch拆解残差网络的秘密武器当你在ImageNet竞赛的历史榜单上看到ResNet这个名字时可能会好奇为什么这个2015年提出的网络结构至今仍是计算机视觉任务的基石答案藏在那个看似简单的加号里——残差连接。但理解这个概念最好的方式不是盯着论文图表而是亲手用代码构建它。本文将带你用PyTorch实现一个完整的ResNet-18在代码层面揭示残差网络的核心机制。1. 残差网络的设计哲学深度学习模型随着层数增加会出现一个反直觉现象更多层数反而导致性能下降。这不是过拟合问题而是优化难题——梯度在反向传播时逐渐消失使得深层网络难以训练。ResNet的突破性在于将传统的直接拟合目标函数转变为拟合残差函数。想象你在学习骑自行车时的进步过程。你不是每次尝试都从零开始而是在前一次尝试的基础上做微小调整。残差块正是模拟这种学习方式# 传统网络层的数学表达 y F(x) # 残差块的数学表达 y F(x) x # 关键加号这个简单的加法操作带来了三个革命性优势梯度高速公路即使深层梯度很小恒等映射x也能确保梯度直接回传解耦学习目标让网络专注于学习输入与输出之间的差值残差动态深度适应极端情况下网络可以通过将F(x)学习为0来退化为浅层网络2. 构建ResNet-18的基础模块ResNet-18使用两种基本构建块普通块BasicBlock和瓶颈块Bottleneck。我们先实现适用于浅层网络的BasicBlockimport torch import torch.nn as nn class BasicBlock(nn.Module): expansion 1 # 输出通道的扩展系数 def __init__(self, in_channels, out_channels, stride1): super().__init__() # 第一个卷积层 self.conv1 nn.Conv2d( in_channels, out_channels, kernel_size3, stridestride, padding1, biasFalse ) self.bn1 nn.BatchNorm2d(out_channels) # 第二个卷积层 self.conv2 nn.Conv2d( out_channels, out_channels, kernel_size3, stride1, padding1, biasFalse ) self.bn2 nn.BatchNorm2d(out_channels) # 快捷连接处理 self.shortcut nn.Sequential() if stride ! 1 or in_channels ! self.expansion * out_channels: self.shortcut nn.Sequential( nn.Conv2d( in_channels, self.expansion * out_channels, kernel_size1, stridestride, biasFalse ), nn.BatchNorm2d(self.expansion * out_channels) ) def forward(self, x): out torch.relu(self.bn1(self.conv1(x))) out self.bn2(self.conv2(out)) out self.shortcut(x) # 残差连接 out torch.relu(out) return out这个实现中有几个关键设计点卷积核配置两个3×3卷积保持感受野的同时减少参数量第一个卷积的stride可能为2下采样使用BatchNorm加速收敛并稳定训练快捷连接处理当输入输出维度匹配时直接相加恒等映射维度不匹配时通过1×1卷积调整投影映射激活函数放置每个卷积后立即接ReLU残差相加后再接一次ReLU3. 完整ResNet-18的架构实现现在我们将BasicBlock组装成完整的ResNet-18。网络分为五个阶段初始卷积层conv1四个残差阶段conv2_x到conv5_x全局平均池化和全连接层class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes1000): super().__init__() self.in_channels 64 # 初始卷积层 self.conv1 nn.Conv2d( 3, 64, kernel_size7, stride2, padding3, biasFalse ) self.bn1 nn.BatchNorm2d(64) self.maxpool nn.MaxPool2d(kernel_size3, stride2, padding1) # 四个残差阶段 self.layer1 self._make_layer(block, 64, num_blocks[0], stride1) self.layer2 self._make_layer(block, 128, num_blocks[1], stride2) self.layer3 self._make_layer(block, 256, num_blocks[2], stride2) self.layer4 self._make_layer(block, 512, num_blocks[3], stride2) # 分类头 self.avgpool nn.AdaptiveAvgPool2d((1, 1)) self.fc nn.Linear(512 * block.expansion, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): strides [stride] [1]*(num_blocks-1) layers [] for stride in strides: layers.append(block( self.in_channels, out_channels, stride )) self.in_channels out_channels * block.expansion return nn.Sequential(*layers) def forward(self, x): x torch.relu(self.bn1(self.conv1(x))) x self.maxpool(x) x self.layer1(x) x self.layer2(x) x self.layer3(x) x self.layer4(x) x self.avgpool(x) x torch.flatten(x, 1) x self.fc(x) return x创建ResNet-18实例的方法def resnet18(): return ResNet(BasicBlock, [2, 2, 2, 2])层数计算验证conv1: 1层conv2_x: 2个block × 2层 4层conv3_x: 2个block × 2层 4层conv4_x: 2个block × 2层 4层conv5_x: 2个block × 2层 4层fc: 1层 总计1 4×4 1 18层4. 训练技巧与可视化分析实现网络结构只是第一步正确的训练方法同样重要。以下是训练ResNet的关键技巧学习率调度策略optimizer torch.optim.SGD(model.parameters(), lr0.1, momentum0.9, weight_decay1e-4) scheduler torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones[30, 60, 90], gamma0.1 )数据增强配置train_transform transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness0.4, contrast0.4, saturation0.4), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ])梯度流动可视化 通过hook机制观察残差连接如何影响梯度传播def register_hooks(model): gradients {} def save_grad(name): def hook(module, grad_input, grad_output): gradients[name] grad_output[0].mean().item() return hook for name, module in model.named_modules(): if isinstance(module, nn.Conv2d): module.register_full_backward_hook(save_grad(name)) return gradients特征图可视化对比 比较有无残差连接时的中间层激活差异层深度传统网络激活强度ResNet激活强度浅层0.78 ± 0.120.82 ± 0.15中层0.31 ± 0.080.67 ± 0.11深层0.02 ± 0.010.54 ± 0.09数据表明残差连接有效缓解了梯度消失问题使深层网络保持活跃学习状态。5. 进阶话题Bottleneck设计与变体虽然ResNet-18使用BasicBlock但更深层的ResNet需要Bottleneck设计来控制计算量。理解这种差异对掌握ResNet家族至关重要。BottleneckBlock实现class BottleneckBlock(nn.Module): expansion 4 # 输出通道扩展系数 def __init__(self, in_channels, out_channels, stride1): super().__init__() # 1×1卷积降维 self.conv1 nn.Conv2d( in_channels, out_channels, kernel_size1, stride1, biasFalse ) self.bn1 nn.BatchNorm2d(out_channels) # 3×3卷积处理特征 self.conv2 nn.Conv2d( out_channels, out_channels, kernel_size3, stridestride, padding1, biasFalse ) self.bn2 nn.BatchNorm2d(out_channels) # 1×1卷积升维 self.conv3 nn.Conv2d( out_channels, out_channels * self.expansion, kernel_size1, stride1, biasFalse ) self.bn3 nn.BatchNorm2d(out_channels * self.expansion) # 快捷连接 self.shortcut nn.Sequential() if stride ! 1 or in_channels ! out_channels * self.expansion: self.shortcut nn.Sequential( nn.Conv2d( in_channels, out_channels * self.expansion, kernel_size1, stridestride, biasFalse ), nn.BatchNorm2d(out_channels * self.expansion) ) def forward(self, x): out torch.relu(self.bn1(self.conv1(x))) out torch.relu(self.bn2(self.conv2(out))) out self.bn3(self.conv3(out)) out self.shortcut(x) out torch.relu(out) return outBottleneck设计优势分析计算效率输入256维 → 64维 → 64维 → 256维参数量1×1×256×64 3×3×64×64 1×1×64×256 70,400直接3×3卷积3×3×256×256 589,824信息流动降维后在小空间进行昂贵卷积运算升维恢复通道数匹配残差连接ResNet变体对比模型参数量(M)GFLOPsTop-1 Acc(%)ResNet-1811.71.869.8ResNet-3421.83.773.3ResNet-5025.64.176.2ResNet-10144.57.977.4实际项目中ResNet-50通常是精度与效率的最佳平衡点。
http://www.gsyq.cn/news/1394128.html

相关文章:

  • FPGA加速机器学习分子动力学:从算法到硬件的协同设计实践
  • ChatGPT之外的6个精准学术搜索AI,支持中文文献溯源、PDF解析与引用生成,毕业季前必存!
  • 【深度体验】萤石C1HC增强夜视版:百元级安防摄像头的真实力与场景适配性
  • GANs生成对抗网络破解水务数据困境:七种模型实战对比与选型指南
  • 长期使用Taotoken聚合服务对于项目运维复杂度的实际影响
  • 基于BERT+CNN+BiLSTM的医疗文本分类模型实战解析
  • 避坑指南:ArcGIS 10.2创建网络数据集时,如何正确处理道路方向和属性(以国道省道为例)
  • 混元3D-Part集成实战:三维部件语义到Unity/UE渲染管线的可信映射
  • PerfectDou实战指南:5分钟让你的斗地主AI碾压人类玩家
  • Kindle电子书封面损坏终极修复指南:一键恢复精美书封
  • mysql面试题专辑
  • 无网络环境下部署MuMu模拟器的完整指南
  • 北京正规美国移民公司深度解析:弘山移民的核心优势 - 奔跑123
  • 基于居家传感器与机器学习的老年人健康预警系统实战解析
  • Windows缩略图加载革命:智能预加载技术让你告别文件夹卡顿
  • 体育直播互动系统开发终极方案:WebRTC+Redis Streams+自研弹幕分片算法,延迟<400ms
  • 2026年多资产流式数据API选型指南:WebSocket实战与架构设计
  • VOSviewer 实战解析:从数据到知识图谱的构建
  • idea, 显示未提交的代码
  • 六安装修公司哪家好?零增项装修怎么避坑(2026实测) - 资讯速览
  • 三个方法,看清Mac的GPU有没有在干活?
  • 柔性超声与Transformer融合:实现手部动作与力量同步高精度识别
  • 从有序链表合并看链表算法的指针设计:LeetCode 21「合并两个有序链表」深度解析
  • MFC实战:从零构建一个带历史记录的计算器
  • 28nm CMOS Via二极管:高密度RRAM阵列的工艺兼容性选择器方案
  • 2026小红书视频提取方法大全|小红书视频提取免费工具实测推荐 - 科技热点发布
  • 二维码扫描模组怎么选?从技术参数与应用场景综合选型分析
  • Python SQLAlchemy实战:构建PostgreSQL数据操作层
  • 2026年湖南钢模板定制租赁全攻略:从BIM设计到共享平台,如何避坑降本30%+ - 企业名录优选推荐
  • 智能游戏助手Seraphine:英雄联盟排位赛的自动BP与数据分析神器