HGNN代码架构解析:从数据加载到模型训练的完整流程
HGNN代码架构解析:从数据加载到模型训练的完整流程
【免费下载链接】HGNNHypergraph Neural Networks (AAAI 2019)项目地址: https://gitcode.com/gh_mirrors/hgn/HGNN
Hypergraph Neural Networks (HGNN) 是一种创新的深度学习框架,专为处理高阶数据相关性而设计。本文将带你深入了解HGNN项目的代码架构,从数据加载到模型训练的完整流程,帮助你快速掌握这一强大工具的使用方法。
项目架构概览
HGNN项目采用模块化设计,主要包含以下几个核心目录:
- config/: 配置文件目录,包含项目的核心参数设置
- datasets/: 数据处理模块,负责数据加载和超图构建
- models/: 模型定义目录,包含HGNN网络结构和核心层实现
- utils/: 工具函数目录,提供超图处理等辅助功能
这种清晰的结构设计使得代码易于理解和扩展,即使是深度学习新手也能快速上手。
配置系统详解
HGNN的配置系统集中在config/config.yaml文件中,通过修改这个文件可以灵活调整模型训练的各种参数。主要配置项包括:
数据路径配置
data_root: &d_r /home/fengyifan/data/features modelnet40_ft: !join [*d_r, ModelNet40_mvcnn_gvcnn.mat] ntu2012_ft: !join [*d_r, NTU2012_mvcnn_gvcnn.mat]超图构建参数
graph_type: &g_t hypergraph K_neigs: [10] m_prob: 1.0 is_probH: True use_mvcnn_feature_for_structure: True use_gvcnn_feature_for_structure: True模型参数设置
on_dataset: &o_d ModelNet40 #on_dataset: &o_d NTU2012 use_mvcnn_feature: False use_gvcnn_feature: True n_hid: 128 drop_out: 0.5训练参数配置
max_epoch: 600 lr: 0.001 milestones: [100] gamma: 0.9 print_freq: 50 weight_decay: 0.0005通过调整这些参数,你可以控制数据加载、超图构建、模型结构和训练过程的各个方面。
数据加载与超图构建流程
HGNN的核心特色在于其对超图结构的处理能力。数据加载和超图构建的主要逻辑在datasets/data_helper.py中实现,通过load_feature_construct_H函数完成。
在train.py中,数据加载和超图构建的流程如下:
# 初始化数据 data_dir = cfg['modelnet40_ft'] if cfg['on_dataset'] == 'ModelNet40' \ else cfg['ntu2012_ft'] fts, lbls, idx_train, idx_test, H = \ load_feature_construct_H(data_dir, m_prob=cfg['m_prob'], K_neigs=cfg['K_neigs'], is_probH=cfg['is_probH'], use_mvcnn_feature=cfg['use_mvcnn_feature'], use_gvcnn_feature=cfg['use_gvcnn_feature'], use_mvcnn_feature_for_structure=cfg['use_mvcnn_feature_for_structure'], use_gvcnn_feature_for_structure=cfg['use_gvcnn_feature_for_structure']) G = hgut.generate_G_from_H(H)这个过程主要完成:
- 根据配置选择数据集
- 加载特征数据和标签
- 构建超图 incidence 矩阵 H
- 从超图生成 G 矩阵用于后续计算
超图构建是HGNN的关键步骤,它能够捕获数据中的高阶相关性,这也是HGNN相比传统图神经网络的优势所在。
HGNN模型结构解析
HGNN模型定义在models/HGNN.py中,核心网络结构如下:
model_ft = HGNN(in_ch=fts.shape[1], n_class=n_class, n_hid=cfg['n_hid'], dropout=cfg['drop_out'])HGNN模型包含以下几个关键部分:
- 输入层:接收节点特征
- 超图卷积层:实现超图上的信息传递
- 激活函数:引入非线性变换
- ** dropout层**:防止过拟合
- 输出层:产生最终分类结果
超图卷积层的实现细节在models/layers.py中,这是HGNN的核心创新点,能够有效处理超图结构中的高阶关系。
训练流程详解
HGNN的训练流程在train.py中实现,主要包含以下步骤:
1. 环境准备与参数初始化
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # 数据转换到设备 fts = torch.Tensor(fts).to(device) lbls = torch.Tensor(lbls).squeeze().long().to(device) G = torch.Tensor(G).to(device) idx_train = torch.Tensor(idx_train).long().to(device) idx_test = torch.Tensor(idx_test).long().to(device)2. 模型、优化器和损失函数设置
optimizer = optim.Adam(model_ft.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay']) schedular = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg['milestones'], gamma=cfg['gamma']) criterion = torch.nn.CrossEntropyLoss()3. 训练循环实现
train_model函数实现了完整的训练循环,包括:
- 训练和验证阶段切换
- 前向传播和反向传播
- 损失计算和参数更新
- 模型保存和性能跟踪
4. 启动训练
model_ft = train_model(model_ft, criterion, optimizer, schedular, cfg['max_epoch'], print_freq=cfg['print_freq'])快速上手HGNN
要开始使用HGNN,只需按照以下步骤操作:
1. 克隆仓库
git clone https://gitcode.com/gh_mirrors/hgn/HGNN2. 安装依赖
安装PyTorch 0.4.0和yaml等依赖库,代码已在Python 3.6、Pytorch 0.4.0和CUDA 9.0环境下测试通过。
3. 配置数据集
下载所需的数据集特征文件:
- ModelNet40_mvcnn_gvcnn_feature
- NTU2012_mvcnn_gvcnn_feature
修改config/config.yaml中的data_root和result_root路径。
4. 调整参数
根据需要调整配置文件中的参数,如选择数据集、特征类型等:
# 选择数据集 on_dataset: &o_d ModelNet40 #on_dataset: &o_d NTU2012 # 选择特征 use_mvcnn_feature: False use_gvcnn_feature: True5. 启动训练
python train.py总结
HGNN通过创新的超图神经网络结构,为处理复杂数据的高阶相关性提供了强大工具。本文详细解析了HGNN的代码架构,包括配置系统、数据加载、模型结构和训练流程。通过本文的指南,你应该能够快速理解和使用HGNN进行节点分类等任务。
如果你对超图神经网络感兴趣,可以进一步研究models/layers.py中的超图卷积实现,或参考原始论文了解更多理论细节。
【免费下载链接】HGNNHypergraph Neural Networks (AAAI 2019)项目地址: https://gitcode.com/gh_mirrors/hgn/HGNN
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
