当前位置: 首页 > 工具软件 > PyG > 使用案例 >

PyG(PyTorch Geometric)安装教程(附Cora数据集)

吕修筠
2023-12-01

PyG(PyTorch Geometric)安装教程(附Cora数据集)


PyG是多特蒙德工业大学(Technische University Dortmund)的Matthias Fey博士基于PyTorch框架提出的图卷积(Graph Convolution Net)神经网络框架,用于方便地编写和训练图神经网络(GNNs)。该框架由各种对图和其他不规则结构进行深度学习的方法组成,且提供了大量的公共基准数据集,可以进行任意图形、3D格网甚至是点云数据的学习和处理。官方网址


我们来介绍PyG在Windows环境下的安装。

首先,我们需要下载好对应版本的pytorch,这里不做过多赘述。

在对应环境中,我们可以直接使用以下conda命令进行安装:

conda install pyg -c pyg
# 或者
pip install torch_geometric

除此之外,我们需要安装PyG的依赖库:torch-sparestorch-scatter,根据PyTorch版本不同,安装的方式有所不同。注意查询Torch的版本。


PyTorch 1.11 版本✅

pip命令如下:

pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.11.0+${CUDA}.html

注意将${CUDA}换成自己的版本,否则会报错。

例如,我的PyTorch版本是:

# 通过pip list查看
# 或者python import torch print(torch.__version__)
torch                 1.11.0+cu113

那么,只需要做如下替换:

pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.11.0+cu113.html

值得注意的是,cuxxx并不支持macOS。


PyTorch 1.10版本✅

pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+${CUDA}.html

其他版本的需要根据PyTorch版本指定依赖库的版本!

其他版本请参考


Cora数据集介绍

Cora数据集由2078篇机器学习领域的论文构成,每个样本点都是一篇论文,这些论文主要分为了7个类别,分别为基于案例、遗传算法、神经网络、概率方法、强化学习、规则学习与理论。在该数据集中,每篇论文都至少引用了该数据集中的另一篇论文,对每个节点所代表的论文,都由一个1433维的词向量表示,即该图上每个节点都具有1433个特征,词向量的每个元素都对应一个词,且该元素仅有0或1两个取值,取0表示该元素对应的词不在论文中,取1表示在。

Cora参数

  • ind.cora.x : 训练集节点特征向量,保存对象为:scipy.sparse.csr.csr_matrix,实际展开后大小为: (140, 1433)
  • ind.cora.tx : 测试集节点特征向量,保存对象为:scipy.sparse.csr.csr_matrix,实际展开后大小为: (1000, 1433)
  • ind.cora.allx : 包含有标签和无标签的训练节点特征向量,保存对象为:scipy.sparse.csr.csr_matrix,实际展开后大小为:(1708, 1433),可以理解为除测试集以外的其他节点特征集合,训练集是它的子集
  • ind.cora.y : one-hot表示的训练节点的标签,保存对象为:numpy.ndarray
  • ind.cora.ty : one-hot表示的测试节点的标签,保存对象为:numpy.ndarray
  • ind.cora.ally : one-hot表示的ind.cora.allx对应的标签,保存对象为:numpy.ndarray
  • ind.cora.graph : 保存节点之间边的信息,保存格式为:{ index : [ index_of_neighbor_nodes ] }
  • ind.cora.test.index : 保存测试集节点的索引,保存对象为:List,用于后面的归纳学习设置。

Cora下载

可以通过以下语句进行下载,当然,由于服务器和主机之间的物理距离,很可能会出现下载失败。

from torch_geometric.datasets import Planetoid
dataset=Planetoid(root=r"./Cora",name="Cora") # root: 指定路径 name: 数据集名称

我将数据集放在百度网盘上了哦,有需要的可以自行下载:

链接
提取码:UCAS

检查其是否成功:

将文件解压放在给定的root下即可。

print("网络数据包含的类数量:",dataset.num_classes)
print("网络数据边的特征数量:",dataset.num_edge_features)
print("网络数据边的数量:",dataset.data.edge_index.shape[1]/2) # 除以2是OOC的组织形式
print("网络数据节点的特征数量:",dataset.num_node_features)
print("网络数据节点的数量:",dataset.data.x.shape[0])

'''
网络数据包含的类数量: 7
网络数据边的特征数量: 0
网络数据边的数量: 5278.0
网络数据节点的特征数量: 1433
网络数据节点的数量: 2708
'''
 类似资料: