pytorch_geometric(pyg)复现T-GCN

东门俊智
2023-12-01

前言

上一篇文章从pyg提供的基本工具出发,介绍了pyg。但是大家用三方库,一般是将其作为积木来构建一个比较大的模型,把它用在自己的数据集上,而不是满足于跑跑demo里的简单模型和标准数据集。因此本文将从复现T-GCN(论文和官方源码见此)的角度出发,讲述怎么使用pyg搭建一个GNN-RNN模型,包括数据集的构建和模型的搭建。

刚开始复现的时候,我踩了很多坑,有的坑是因为不熟悉pyg踩的,有的坑是因为作者论文里和源码里的模型不一致踩的。这里说是复现,但是是以作者源码里的模型为准。那你可能会问了,都有源码了,我在这里吵吵啥“复现”呢?因为源码不是采用pyg写的,而是使用了原始的GCN计算方式,使用一些矩阵乘法做的。

前置准备

虽然我不是作者团队的……但是我觉得还是有必要介绍一下这篇文章的模型和数据集

T-GCN介绍

这里介绍的T-GCN的全称是Traffic-GCN,同学们有可能会在别的地方看到这个简称,但是不一定指的是这个模型。

T-GCN的核心模型结构是使用了GCN+GRU二者组合,先使用GCN得到更丰富的节点特征,再将每个节点的特征都送入GRU中进行计算。相当于使用GCN聚合空间特征,再使用GRU聚合时序特征,具体的计算公式如下。需要注意的是,每次输入的都是当前时间的特征和GRU的隐层特征,二者拼接后作为T-GCN-Cell(可以认为是T-GCN内部的一层卷积)的输入。

C o n v t ( X t , h t ) = L ⋅ c o n c a t ( X t , h t ) G R U ( X t , h t ) = u t = σ ( W u C o n v t ( X t , h t ) + b u ) r t = σ ( W r C o n v t ( X t , h t ) + b r ) c t = t a n h ( W c C o n v t ( X t , h t ) + b c ) h t + 1 = u t ∗ h t + ( 1 − u t ) ∗ c t Conv_t(X_{t},h_{t})=L\cdot concat(X_{t},h_{t})\\ GRU(X_{t},h_{t})= \begin{aligned} u_t & = \sigma(W_uConv_t(X_{t},h_{t})+b_u) \\ r_t & = \sigma(W_rConv_t(X_{t},h_{t})+b_r) \\ c_t & = tanh(W_cConv_t(X_{t},h_{t})+b_c)\\ h_{t+1} & = u_t*h_t+(1-u_t)*c_t \end{aligned}\\ Convt(Xt,ht)=Lconcat(Xt,ht)GRU(Xt,ht)=utrtctht+1=σ(WuConvt(Xt,ht)+bu)=σ(WrConvt(Xt,ht)+br)=tanh(WcConvt(Xt,ht)+bc)=utht+(1ut)ct

数据集介绍

这里只采用“shenzhen”数据集。该数据集是在深圳156条道路采集的交通流量数据,采集间隔为5分钟一次,维度为1。此外,还附带一个道路间的邻接矩阵。也就是作者在建模过程中,将道路当作节点,将道路是否连通作为建图标准。

开始复现

使用pyg复现T-GCN的过程是比较痛苦的,因为必须要削足适履。个人认为pyg对时序数据的支持似乎不那么友好,当然也有可能是因为我没找到适合时序数据使用的DataLoader和Data对象。

时序序列的GNN数据有什么问题?

在之前的pyg介绍文章就说过,DataLoader会将每个Data对象视为一个图,形成mini-batch时,将一个batch里的Data对象打包成一个大图。这在非时序样本时,没有任何问题,但是对于时序数据,每个样本中包含多个图,此时DataLoader打包出来的对象可能就不符合我们的心意了。

静态图

我们先假设一种最简单的情况,每个样本的图结构和连接关系完全相同(这也被称为“静态图”),因此我们给每个Data对象都是完全相同的邻接矩阵。为了追求并行化,我们通常把一个时间点的所有数据抽出来一起计算。假设初始时隐状态为零向量,T-GCN的大致伪码如下图所示。

# x.shape is [num_nodes, seq_len, num_features]
h = zeros()
for i in range(seq_len):
    h = gru(gcn(concat(x[:, i, :], h), edge_index))

这么乍一看,好像没问题,实际上也确实没问题。edge_index按照mini-batch的方式拼接成大图;x拼接之后形成了[batch_size*num_nodes, seq_len, num_features]的矩阵,每次取其中一个时间点进行运算,输入的逻辑非常正确。

动态图

然后我们再看,假如样本中的图结构关系(这里只考虑边变化的情况)可以随着时间动态变化,那对于每个时间点,都需要一张独立的图,此时Data对象规定的edge_index结构就不满足我们的要求了,与之配合的DataLoader也会拼接出错误的mini-batch。

对于这种情况,我前思后想,辗转反侧,想到了一个相当削足适履的方法,就是改造DataLoader生成mini-batch的函数,使之对样本中每一个时间点的邻接矩阵进行拼接,然后生成一个List,维度为[seq_len, 2, num_edges],同时需要注意,每张图的num_edges可能不一样,所以这样一个数据还没法打包到一个Tensor里,只能用List存下所有时间点的batch大图

不过幸好,T-GCN是静态图,没这么麻烦,这种情况只是自己在做磕盐的时候遇到的,如果大家有更好的方法,也欢迎讨论。

搭建削足适履的模型

DataSet

然后我们搭建DataSet对象

from typing import List, Union, Tuple

import numpy as np
import torch

from torch_geometric.data import InMemoryDataset, Dataset, Data
from utils.utils import dataset_path
from constant import DATASET_NAME_TRAFFIC
import pandas as pd

class TrafficDataSet(InMemoryDataset):
    # 一个点是15分钟
    seq_len = 4
    predict_len = 1
    DATASET_TYPE = 'sz'
    PROCESSED_DATASET_FILENAME = '%s_seq%d_pre%d' % (DATASET_TYPE, seq_len, predict_len)
    speed_name = DATASET_TYPE + '_speed.csv'
    adj_name = DATASET_TYPE + '_adj.csv'

    def __init__(self):
        super().__init__(root=dataset_path(DATASET_NAME_TRAFFIC))
        self.data, self.slices, self.max_speed, self.num_nodes, self.seq_len, self.pre_len = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        return [TrafficDataSet.speed_name, TrafficDataSet.adj_name]

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        return TrafficDataSet.PROCESSED_DATASET_FILENAME + '.pt'

    def download(self):
        pass

    def process(self):
        # 一个文件里是有抬头的,一个没有
        speed = pd.read_csv(self.raw_paths[0]).values
        adj = pd.read_csv(self.raw_paths[1], header=None).values

        num_nodes = len(adj)

        adj = process_adj(adj)
        # 对样本的输出进行归一化,归一化参数需要记录下来,计算测试集MSE时要用
        max_speed = np.max(speed)
        speed = speed / max_speed

        speed = torch.tensor(speed, dtype=torch.float32)
        adj = torch.tensor(adj, dtype=torch.int64)

        time_len = speed.shape[0]
        seq_len = TrafficDataSet.seq_len
        pre_len = TrafficDataSet.predict_len
        data_list = []
        for i in range(time_len - seq_len - pre_len):
            # speed = [time_len, num_nodes]
            # x = [num_nodes, seq_len, num_features=1]
            x = speed[i: i + seq_len].transpose(0,1).reshape([num_nodes, seq_len, 1])

            # y = [pre_len, num_nodes] -> [num_nodes, pre_len]
            y = speed[i + seq_len: i + seq_len + pre_len].transpose(0, 1)

            pyg_data = Data(x, edge_index=adj, y=y)
            data_list.append(pyg_data)

        data, slices = self.collate(data_list)
        torch.save((data, slices, max_speed, num_nodes, seq_len, pre_len), self.processed_paths[0])

# 数据集给的是邻接矩阵,需要转换成pyg接受的稀疏矩阵的形式
def process_adj(adj):
    node_cnt = len(adj)
    pyg_adj = [[],[]]
    for i in range(node_cnt):
        for j in range(node_cnt):
            if adj[i][j] == 1:
                pyg_adj[0].append(i)
                pyg_adj[1].append(j)
    return np.array(pyg_adj)

T-GCN模型

模型听起来也不复杂,因此对照着源码直接开始搭建

import torch
from torch_geometric.nn.conv import GCNConv

import torch.nn.functional as F
class TGCN_Conv_Module(torch.nn.Module):
    def __init__(self, args):
        super(TGCN_Conv_Module, self).__init__()
        self.args = args

        self.num_features = args.c_in
        self.nhid = args.c_out

        # 卷积层将输入与GRU的hidden_state拼接起来作为输入,输出hidden_size的特征
        self.conv1 = GCNConv(self.num_features+self.nhid, self.nhid)

    def forward(self, x, edge_index):
        # 实际上作者源码中只使用了一层GCN卷积,而论文中是两层
        x = F.relu(self.conv1(x, edge_index))
        x = torch.sigmoid(x)

        return x


class TGCNCell(torch.nn.Module):
    def __init__(self, args):
        super(TGCNCell, self).__init__()
        self.args = args
        self.num_features = args.c_in
        self.nhid = args.c_out
        self.seq_len = args.seq_len
        self.num_nodes = args.num_nodes

        # 这是仿照作者源码里的写法,实际上这是两个GCN,在forward函数中会将其输出拆成两半
        self.graph_conv1 = GCNConv(self.nhid+self.num_features, self.nhid * 2)
        self.graph_conv2 = GCNConv(self.nhid+self.num_features, self.nhid)

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.constant_(self.graph_conv1.bias, 1.0)

    def forward(self, x, edge_index, hidden_state):

        ru_input = torch.concat([x, hidden_state], dim=1)

        # 这里将一个GCN的输出拆成两半,如果熟悉其矩阵写法的话,实际上就是用了俩GCN
        # 但是这里的拆分函数也是仿照源码,个人觉得拆分的维度不对,但是这么写的准确率高
        ru = torch.sigmoid(self.graph_conv1(ru_input, edge_index))
        r, u = torch.chunk(ru.reshape([-1, self.num_nodes * 2 * self.nhid]), chunks=2, dim=1)
        r = r.reshape([-1, self.nhid])
        u = u.reshape([-1, self.nhid])

        c_input = torch.concat([x, r * hidden_state], dim=1)
        c = torch.tanh(self.graph_conv2(c_input, edge_index))

        new_hidden_state = u * hidden_state + (1.0 - u) * c
        return new_hidden_state

# 先进行图级别聚合,再进行序列建模
class RNNProcessHelper(torch.nn.Module):
    def __init__(self, args, rnn_cell):
        super(RNNProcessHelper, self).__init__()
        self.args = args
        self.num_features = args.c_in
        self.nhid = args.c_out
        self.out_dim = args.out_dim
        self.seq_len = args.seq_len
        self.num_nodes = args.num_nodes

        self.rnn_cell = rnn_cell

    def forward(self, data, hidden_state=None):
        x, edge_index = data.x, data.edge_index
        if type(edge_index) is torch.Tensor:
            is_seq_edge_index = False
        elif type(edge_index) is list:
            is_seq_edge_index = True
        else:
            raise '没有边连接信息!'

        if not hidden_state:
            hidden_state = torch.zeros([x.shape[0], self.nhid]).to(self.args.device)

        hidden_state_list = []
        for i in range(self.seq_len):
            # return gru_output.shape = [batch_size*num_nodes, hidden_size]
            if is_seq_edge_index:
                hidden_state = self.rnn_cell(x[:, i, :], edge_index[i], hidden_state)
            else:
                hidden_state = self.rnn_cell(x[:, i, :], edge_index, hidden_state)
            hidden_state_list.append(hidden_state)

        return hidden_state_list

# 回归任务
class TGCN_Reg_Net(torch.nn.Module):
    def __init__(self, args):
        super(TGCN_Reg_Net, self).__init__()
        self.args = args
        self.num_features = args.c_in
        self.nhid = args.c_out
        self.out_dim = args.out_dim
        self.seq_len = args.seq_len
        self.num_nodes = args.num_nodes

        # self.tgcn_cell = TGCN_Cell(args)
        tgcn_cell = TGCNCell(args)
        self.seq_process_helper = RNNProcessHelper(args, tgcn_cell)

        # 将每个节点最终的hidden_state -> 该节点未来3小时的车速
        self.lin_out = torch.nn.Linear(self.nhid, self.out_dim)

    def forward(self, data):
        hidden_state_list = self.seq_process_helper(data)

        # 选最后一个output,用于预测
        hidden_state_last = hidden_state_list[-1]
        out = self.lin_out(hidden_state_last)

        # 按照数据集的构建方式,[batch*num_nodes, out_dim]
        return out

    @staticmethod
    def test(model, loader, args) -> float:
        import math
        model.eval()
        loss = 0.0
        max_speed = args.max_speed
        # 因为batch=1,所以一次是算一个样本的mse
        for data in loader:
            data = data.to(args.device)
            out = model(data)
            loss += F.mse_loss(out, data.y).item()

        mse_loss = loss / len(loader.dataset)
        rmse_loss = math.sqrt(mse_loss) * max_speed
        # print("val RMSE loss:{}".format(rmse_loss))

        return rmse_loss

    @staticmethod
    def get_loss_function():
        from utils.loss_utils import mse_loss
        return mse_loss

主函数

import math

import torch
from torch_geometric.loader import DataLoader

from utils.dataset_utils import split_dataset_by_ratio
from classfiers.tgcn import TGCN_Reg_Net
from datasets.traffic import TrafficDataSet
from utils.args_utils import get_args_pred
from utils.task_utils import train

if __name__ == '__main__':

    dataset = TrafficDataSet()
    train_set, test_set = split_dataset_by_ratio(dataset)

    args = get_args_pred(dataset)

    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    model = TGCN_Reg_Net(args).to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    train(model, train_loader, test_loader, optimizer, args)

整个模型的大致搭建过程就是这样,还有一些工具函数没有给出,但是看名字也大致能知道是干什么的

复现中碰到的其他问题

在我复现的过程中,好不容易完成了模型的搭建,看起来也和源码里的一模一样了,但是最终的指标差很多。然后尝试着看了看源码中对超参数的规定,发现有一个很微妙的参数,weight_decay,作者并不是将其加在了优化器Adam的构造函数中(也就是令Adam的weight_decay=0,而是在计算loss时,在mse_loss的基础上,加上了模型参数的l2正则化损失,其计算方式如下

def regular_loss(model, lamda=0):
    reg_loss = 0.0
    for param in model.parameters():
        reg_loss += torch.sum(param ** 2)
    return lamda * reg_loss

def mse_loss(out, label, model, reg_weight=0):
    classify_loss = F.mse_loss(out.squeeze(), label.squeeze())
    reg_loss = regular_loss(model, reg_weight)
    return classify_loss + reg_loss

后来查阅一些资料发现,这是因为Adam对模型的惩罚力度也会随着模型的训练进行自适应调整,使用AdamW可以解决这一问题,然而实际上也没什么卵用。因此,对于我这种半桶水来说,还是老老实实用作者调出来的参数吧……

后记

这里简要介绍了自己复现T-GCN的过程,把模型和DataSet的构建过程贴了出来。现在暂时没有整理可以直接运行的源码供大家下载,因为作者已经开放了源码,而我只不过是拿pyg重新实现了一下,对于学习pyg本身没有太大的用处,对于学习T-GCN也没有太大的用处。

 类似资料: