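GraphVite can be used in two ways: through its Python interface or from the command line. The previous two posts covered the command-line and the Python-interface approaches respectively. This post walks through training embeddings on a graph of your own from the command line.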
1. Generate a random graph. Here we use networkx to build a Barabási-Albert (BA) scale-free graph and save it in edge-list format:
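import os
import networkx as nx
# BA scale-free graph: 100 nodes, each new node attaches to 2 existing nodes
G = nx.barabasi_albert_graph(100, 2)
os.makedirs('data', exist_ok=True)  # write_edgelist does not create missing directories
nx.write_edgelist(G, 'data/test.edge_list', data=False, delimiter='\t')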
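To show what "BA scale-free" means in practice, here is a simplified preferential-attachment sketch in plain Python. It is an illustration only, not networkx's actual implementation (which may differ in details such as the initial seed graph); the function names `barabasi_albert_edges` and `to_edge_list_text` are hypothetical.

```python
import random


def barabasi_albert_edges(n, m, seed=None):
    """Simplified preferential attachment (illustrative, not networkx's code).

    Starts from m isolated seed nodes; every new node links to m distinct
    existing nodes chosen with probability proportional to their degree.
    """
    if not 1 <= m < n:
        raise ValueError("need 1 <= m < n")
    rng = random.Random(seed)
    edges = []
    targets = list(range(m))  # the first new node links to every seed node
    repeated = []             # each node appears here once per unit of degree
    for source in range(m, n):
        edges.extend((source, t) for t in targets)
        repeated.extend(targets)
        repeated.extend([source] * m)
        # Degree-proportional sampling of m distinct targets for the next node
        chosen = set()
        while len(chosen) < m:
            chosen.add(rng.choice(repeated))
        targets = list(chosen)
    return edges


def to_edge_list_text(edges, delimiter="\t"):
    """Render edges as 'u<delimiter>v' lines, one edge per line."""
    return "".join(f"{u}{delimiter}{v}\n" for u, v in edges)


edges = barabasi_albert_edges(100, 2, seed=0)
print(len(edges))  # 196
print(to_edge_list_text(edges[:3]), end="")
```

Each of the n − m = 98 added nodes contributes m = 2 edges, giving 196 edges; because new nodes prefer already well-connected nodes, a few early nodes end up as hubs, which is the scale-free property.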
2. Next, create a YAML config file, here test.yaml, with the following content:
https://download.csdn.net/download/dongfangxiaozi_/11953105
application:
  graph
resource:
  gpus: [0]
  cpu_per_gpu: 8
  dim: 128
format:
  delimiters: " \t\n"
  comment: "#"
graph:
  file_name: /home/xxx/.graphvite/test.edgelist
  as_undirected: true
build:
  optimizer:
    type: SGD
    lr: 0.025
    weight_decay: 0.005
  num_partition: auto
  num_negative: 1
  batch_size: 100000
  episode_size: 500
train:
  model: LINE
  num_epoch: 2000
  negative_weight: 5
  augmentation_step: 2
  random_walk_length: 40
  random_walk_batch_size: 100
  log_frequency: 1000
save:
  file_name: line_test.pkl
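The train section selects LINE trained with SGD and negative sampling. As a rough illustration of what one update does, here is a pure-Python sketch of a second-order-LINE-style step: a "vertex" table and a "context" table (which presumably correspond to `vertex_embeddings` and `context_embeddings` in the saved file), one positive edge, and `num_negative` negatives whose loss is scaled by `negative_weight`. How GraphVite applies `negative_weight` and `weight_decay` internally is an assumption here, as is the uniform negative sampling; `line_sgd_step` is a hypothetical name.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def line_sgd_step(vertex, context, u, v, negatives, lr=0.025,
                  negative_weight=5.0, weight_decay=0.005):
    """One SGD update for a positive edge (u, v) and its negative samples.

    vertex, context: lists of embedding rows (lists of floats), updated in place.
    Loss (L2 term excluded):
        -log s(x_u . c_v) - negative_weight * sum_n log s(-x_u . c_n)
    Returns that loss, computed from the parameters before the update.
    """
    x_u = vertex[u]
    grad_u = [0.0] * len(x_u)
    loss = 0.0
    samples = [(v, 1.0, 1.0)] + [(n, 0.0, negative_weight) for n in negatives]
    for node, label, weight in samples:
        c = context[node]
        score = sigmoid(sum(a * b for a, b in zip(x_u, c)))
        prob = score if label == 1.0 else 1.0 - score
        loss -= weight * math.log(max(prob, 1e-12))
        g = weight * (score - label)  # gradient of the weighted loss w.r.t. the logit
        for i in range(len(c)):
            grad_u[i] += g * c[i]                            # uses c before its update
            c[i] -= lr * (g * x_u[i] + weight_decay * c[i])  # x_u not updated yet
    for i in range(len(x_u)):
        x_u[i] -= lr * (grad_u[i] + weight_decay * x_u[i])
    return loss


# Tiny demo: word2vec-style init (random vertex rows, zero context rows; assumed)
rng = random.Random(0)
num_nodes, dim = 4, 8
vertex = [[rng.uniform(-0.5, 0.5) / dim for _ in range(dim)] for _ in range(num_nodes)]
context = [[0.0] * dim for _ in range(num_nodes)]
toy_edges = [(0, 1), (1, 2), (2, 3)]
for epoch in range(50):
    for u, v in toy_edges:
        negatives = [rng.randrange(num_nodes)]  # num_negative: 1, uniform for simplicity
        line_sgd_step(vertex, context, u, v, negatives)
```

Keeping separate vertex and context tables is what lets LINE model second-order proximity: two nodes end up close in vertex space when they share neighbours in context space.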
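Set the input file (graph.file_name) to your local path, and omit the angle brackets <> when you fill it in. It must point at the edge list generated in step 1; note that step 1 writes data/test.edge_list, so make the two paths match.
Change the output model path (save.file_name) as well.
Set delimiters in the format section to the separator used in your input file, usually a space or \t.
The dimension (dim) and the model can be adjusted as needed.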
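To make the `format` options concrete, here is a sketch of reading an edge list where any character in `delimiters` separates columns and `comment` marks comments. It only approximates GraphVite's loader: treating `#` as a comment anywhere on a line, and reading `as_undirected: true` as "also add the reverse edge", are assumptions, and `parse_edge_list` is a hypothetical name.

```python
def parse_edge_list(text, delimiters=" \t\n", comment="#", as_undirected=False):
    """Sketch of reading an edge list according to the `format` section."""
    to_space = str.maketrans({d: " " for d in delimiters})
    edges = []
    for raw in text.splitlines():
        # Assumed: the comment marker starts a comment anywhere on the line
        line = raw.split(comment, 1)[0] if comment else raw
        tokens = [t for t in line.translate(to_space).split(" ") if t]
        if not tokens:
            continue  # blank or comment-only line
        if len(tokens) < 2:
            raise ValueError(f"need two node names per line, got {raw!r}")
        u, v = tokens[0], tokens[1]
        edges.append((u, v))
        if as_undirected:
            edges.append((v, u))  # assumed meaning of as_undirected: add the reverse edge
    return edges


sample = "# source\ttarget\n0\t1\n1 2\n\n3\t4  # trailing comment\n"
print(parse_edge_list(sample))
# [('0', '1'), ('1', '2'), ('3', '4')]
```

Because " \t\n" covers both spaces and tabs, the tab-separated file from step 1 and a space-separated file parse the same way.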
3. Then run the command graphvite run test.yaml.
The evaluate section can be omitted.
Training finishes quickly, and you will see output like:
model: LINE
optimizer: SGD
learning rate: 0.025, lr schedule: linear
weight decay: 0.005
#epoch: 2000, batch size: 100000
resume: no
positive reuse: 1, negative weight: 5
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
Batch id: 0 / 3
loss = 0
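The log reports "lr schedule: linear". A common reading of a linear schedule, borrowed from word2vec, is that the learning rate decays linearly from its initial value toward zero over training. The sketch below assumes that behaviour, including a small floor (`min_ratio`) so it never reaches exactly zero; GraphVite's exact formula and floor are not confirmed here, and `linear_lr` is a hypothetical name.

```python
def linear_lr(base_lr, step, total_steps, min_ratio=1e-4):
    """Linearly decay base_lr toward zero over total_steps (floor is assumed)."""
    if total_steps <= 0:
        return base_lr
    progress = min(step / total_steps, 1.0)
    return base_lr * max(1.0 - progress, min_ratio)


for step in (0, 250, 500, 750, 1000):
    print(step, linear_lr(0.025, step, 1000))
```

With lr: 0.025 from the config, the rate is halved by mid-training, which is why a fairly large initial learning rate is usually fine for SGD here.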
4. Now load the embeddings we just learned.
This requires easydict:
pip install easydict
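import pickle
from easydict import EasyDict  # requires: pip install easydict
file_name = '/home/xxx/.graphvite/line_test.pkl'
with open(file_name, 'rb') as fin:  # closes the file after loading
    model = pickle.load(fin, encoding='utf-8')
model  # in an interactive session this displays the dict below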
The output looks like:
{'context_embeddings': array([[-0.00822913, -0.00530941, -0.00695839, ..., 0.00390708,
0.00219729, -0.00652534],
[-0.01103182, -0.00352184, -0.0102704 , ..., 0.00561322,
0.00405805, -0.01209745],
[-0.01001213, -0.00452089, -0.00941589, ..., 0.00354438,
0.00312015, -0.00927031],
...,
[-0.00722677, -0.00073841, -0.00356204, ..., 0.00240901,
0.00181343, -0.00524651],
[-0.00879772, -0.00321459, -0.00741335, ..., 0.00506188,
0.00146051, -0.00558406],
[-0.00872622, -0.00296524, -0.00560616, ..., 0.00240658,
0.00134236, -0.00807665]], dtype=float32),
'id2name': ['0',
....,
'98'],
'vertex_embeddings': array([[ 0.01127131, 0.00733544, 0.00983142, ..., 0.00037273,
-0.00423259, 0.00497633],
[ 0.01100275, 0.00174217, 0.00472951, ..., -0.0049816 ,
-0.00484521, 0.00895557],
[ 0.01126207, 0.00034387, 0.00764804, ..., -0.00499946,
-0.00046137, 0.00998339],
...,
[ 0.00800165, 0.00556528, 0.01023772, ..., -0.00702143,
-0.00526442, 0.00745916],
[ 0.00895165, 0.00092912, 0.00540247, ..., -0.00289551,
0.00077499, 0.00426355],
[ 0.00849834, -0.00018509, 0.00910194, ..., -0.00081561,
-0.00055885, 0.00917041]], dtype=float32)}
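One typical next step is to look up a node's vector by name and find its nearest neighbours. The sketch below assumes, based on the structure printed above, that row i of `vertex_embeddings` belongs to node `id2name[i]`; the helper names are hypothetical, and it is written in plain Python so it works on both the loaded numpy arrays and ordinary lists.

```python
import math


def build_lookup(model, key="vertex_embeddings"):
    """Map node name -> embedding row, assuming row i belongs to id2name[i]."""
    return {name: list(row) for name, row in zip(model["id2name"], model[key])}


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def most_similar(lookup, name, k=5):
    """Return the k nodes whose vectors have the highest cosine similarity."""
    query = lookup[name]
    scores = [(other, cosine(query, vec))
              for other, vec in lookup.items() if other != name]
    scores.sort(key=lambda pair: pair[1], reverse=True)
    return scores[:k]


# Stand-in for the loaded pickle; with the real file, pass `model` instead.
fake = {
    "id2name": ["0", "1", "2"],
    "vertex_embeddings": [[1.0, 0.0], [0.9, 0.1], [-1.0, 0.0]],
}
print(most_similar(build_lookup(fake), "0", k=2))
```

With the real file, `most_similar(build_lookup(model), '0')` lists the nodes closest to node '0'; switching `key` to "context_embeddings" queries the other table.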