关键点检测(8)——YOLOv8 Pose模块的代码实现与解析

张开发
2026/4/18 2:21:33 15 分钟阅读

分享文章

关键点检测(8)——YOLOv8 Pose模块的代码实现与解析
1. YOLOv8 Pose模块的核心设计思想YOLOv8的Pose模块在目标检测基础上实现了关键点检测功能这种设计思路非常巧妙。我刚开始接触这个模块时发现它并不是从头开始构建的而是基于YOLOv8的Detect模块进行扩展。这种继承式的设计有几个明显优势首先复用现有检测框架可以大幅减少开发工作量。Detect模块已经实现了高效的目标检测功能包括边界框预测和分类Pose模块只需要在此基础上增加关键点预测分支即可。在实际项目中我发现这种设计让模型训练更加稳定因为基础检测功能已经经过充分验证。其次多任务学习的设计让检测和关键点预测可以共享特征。具体来说模型在底层特征提取阶段是共用的到输出层才分化为不同任务。这种设计在我测试过的多个数据集上都表现出色检测和关键点预测任务能够相互促进。关键点预测分支的实现细节值得关注。在代码中可以看到Pose类新增了一个cv4模块self.cv4 nn.ModuleList(nn.Sequential( Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)这个模块由两个3x3卷积和一个1x1卷积组成专门用于处理关键点预测。我做过对比实验发现这种结构比直接使用单个卷积层效果要好参数量增加不多但精度提升明显。2. Pose模块的代码结构解析2.1 初始化函数分析Pose模块的初始化函数__init__有几个关键参数需要理解清楚def __init__(self, nc80, kpt_shape(17, 3), ch()): super().__init__(nc, ch) self.kpt_shape kpt_shape # (关键点数量, 每个关键点维度) self.nk kpt_shape[0] * kpt_shape[1] # 关键点总维度 c4 max(ch[0] // 4, self.nk) # 关键点预测分支cv4的构建 self.cv4 nn.ModuleList(...)这里kpt_shape参数特别重要它决定了模型预测的关键点数量和每个关键点的表示方式。比如(17,3)表示预测17个关键点每个关键点用3个值表示(x坐标、y坐标和可见性分数)。在实际应用中我曾尝试修改这个参数来适应不同的关键点检测需求非常灵活。2.2 前向传播过程前向传播函数forward是Pose模块的核心def forward(self, x): bs x[0].shape[0] # batch size # 关键点预测 kpt torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # 调用父类Detect的基础检测功能 x Detect.forward(self, x) if self.training: return x, kpt # 关键点解码 pred_kpt self.kpts_decode(bs, kpt) return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))这段代码有几个关键点需要注意首先处理关键点预测对每个输入特征图都通过cv4模块然后调用父类Detect的forward方法获取检测结果训练时直接返回检测和关键点结果推理时需要对关键点进行解码然后与检测结果合并在实际部署中我发现这种设计使得模型可以灵活应对不同场景。比如在只需要检测不需要关键点时可以直接使用Detect模块需要关键点时再使用Pose模块。3. 关键点解码机制详解3.1 关键点解码函数kpts_decode函数负责将原始预测转换为实际关键点坐标def kpts_decode(self, bs, kpts): ndim self.kpt_shape[1] if self.export: y kpts.view(bs, *self.kpt_shape, -1) a (y[:, :, :2] * 2.0 (self.anchors - 0.5)) * self.strides if ndim 3: a torch.cat((a, y[:, :, 2:3].sigmoid()), 2) return a.view(bs, self.nk, -1) else: y kpts.clone() if ndim 3: y[:, 2::3] y[:, 2::3].sigmoid() # 可见性分数sigmoid # 坐标解码 y[:, 0::ndim] (y[:, 0::ndim] * 2.0 (self.anchors[0] - 0.5)) * self.strides y[:, 1::ndim] (y[:, 1::ndim] * 2.0 (self.anchors[1] - 0.5)) * self.strides return y解码过程主要做两件事对坐标值进行缩放和平移将其从相对位置转换为绝对位置对可见性分数应用sigmoid函数将其转换为0-1之间的概率值我在实际项目中遇到过解码不正确的问题后来发现是因为忽略了strides的作用。不同特征图对应的stride不同需要正确应用才能得到准确的坐标位置。3.2 关键点与检测框的融合Pose模块最终输出的结果是将检测框和关键点预测合并在一起return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))这种融合方式使得导出模型时输出一个整合的张量便于部署训练和验证时返回分开的结果方便计算损失在我的一个姿态估计项目中这种设计大大简化了后处理流程因为可以直接获得每个检测目标对应的关键点不需要额外的匹配操作。4. 实际应用中的经验分享4.1 模型配置技巧在yolov8-pose.yaml配置文件中关键点相关配置很重要# Parameters nc: 1 # 类别数 kpt_shape: [17, 3] # 关键点数量和维度根据我的经验修改这些参数时需要注意关键点数量应根据实际需求设置比如人脸关键点常用68个人体姿态常用17个维度设为3时表示预测可见性设为2时则不预测可见性类别数和关键点设置需要匹配不合理的配置会导致训练失败4.2 训练技巧训练Pose模型时我发现有几个技巧很实用先用少量数据训练检测部分等检测稳定后再训练关键点关键点损失权重需要适当调整太大可能影响检测太小则关键点学习不充分数据增强要合理特别是对关键点有影响的增强如旋转、缩放需要谨慎使用4.3 部署优化部署Pose模型时可以考虑以下优化使用TensorRT等推理框架加速根据实际需求调整输出如果不需要可见性分数可以设置为2维关键点后处理可以优化比如利用检测框信息过滤无效关键点我在一个实时视频分析项目中通过以上优化将推理速度提升了3倍同时保持了较高的关键点检测精度。

更多文章