码迷,mamicode.com
首页 > Web开发 > 详细

Memory-based Graph Networks

时间:2020-03-31 19:16:57      阅读:167      评论:0      收藏:0      [点我收藏+]

标签:矩阵   perm   com   img   计算公式   全连接   查询   鼓励   代码   

论文:《Memory-based Graph Networks》,ICLR2020

代码:https://github.com/amirkhas/GraphMemoryNet

技术图片

概述

图神经网络(GNNs)是一类深度模型,可处理任意拓扑结构的数据。比如社交网络、知识图谱、分子结构等。GNNs通常被用来根据节点的交互关系学习节点的向量表示,典型的模型有gated GNN(Li et al., 2015)、MPNN(Giler et al., 2017)、GCN(Kipf & Welling, 2016)和GAT(Velikovi et al., 2018)。GNNs方法通常优于传统的随机游走、矩阵分解、核方法和概率图模型。

但是,这些模型无法学习到层次表示,因为它们没有利用图的组合性质。DiffPool (Ying et al., 2018)、TopKPool (Gao & Ji, 2019)、SAGPool (Lee et al., 2019)等模型引入参数化的图池化层,通过堆叠交错层和池化层来学习层次图表示。但这些模型的计算效率不高,因为它们需要在每个池化层后进行消息传递计算。

本论文介绍了一个能够同时进行图表示学习和节点聚类的记忆层,该记忆层由多组(multi-head)记忆键和卷积运算组成。记忆键被视为聚类中心,而卷积运算用来聚合多组结果。记忆层的输入叫做query,是前一层输出的节点表示,记忆层的输出是聚类后的节点表示。这种记忆层不显式依赖节点的连接信息,因此不存在过度平滑问题(Xu et al., 2018),同时也改进了效率和性能。

作者在论文中提出了两种基于记忆层的网络,分别叫做memory-based GNN(MemGNN)和graph memory network(GMN)。其中MemGNN就是首先使用GNN学习节点的初始表示然后堆叠记忆层学习层次表示;GMN则不依赖GNN,因此也不需要消息传递的计算。

相关工作

技术图片

方法

下面开始讲记忆层究竟是什么,以及由此而来的两种网络架构,即GMN和MemGNN。

记忆层

\(l\)层的记忆层可以表示为\(\mathcal{M}^{(l)}:\mathbb{R}^{n_l \times d_l} \longmapsto \mathbb{R}^{n_{l+1} \times d_{l+1}}\),记忆层输入\(n_l\)个维度为\(d_l\)的查询向量,生成\(n_{l+1}\)个维度为\(d_{l+1}\)的查询向量(下个记忆层的查询向量)。因为要自底向上学习图层次表示,要保证\(n_{l+1} \lt n_l\)

技术图片

上图就是记忆层的示意图,假设其中有\(|h|\)组记忆键。现在来看看记忆层是怎么实现聚类的。首先,假设第\(l\)层记忆层的输入为\(\mathbf{Q}^{(l)} \in \mathbb{R}^{n_l \times d_l}\),一组记忆键\(\mathbf{K}^{(l)} \in \mathbb{R}^{n_{l+1} \times d_l}\)可以看作是\(\mathbf{Q}^{(l)}\)的聚类中心。为了衡量\(\mathbf{Q}^{(l)}\)\(\mathbf{K}^{(l)}\)每个分量之间的相似度,作者借鉴Xie et al., 2016的工作,使用t分布作为核函数。因此查询\(q_i\)和记忆键\(k_j\)的正则化的相似度定义为:

\[C_{i,j}=\frac{(1+||q_i-k_j||^2/ \tau)^{-\frac{\tau + 1}{2}}}{\sum_{j^{‘}}(1+||q_i-k_{j^{‘}}||^2/ \tau)^{-\frac{\tau + 1}{2}}} \]

\(C_{i,j}\)就是将节点\(i\)分配到类簇\(j\)的概率,或者说\(q_i\)\(k_j\)之间的注意力权重。\(\tau\)是t分布的自由度。前面我们说到,记忆键总共有\(|h|\)组,因此实际上上述聚类要计算\(|h|\)次,得到结果为\([\mathbf{C}_0^{(l)} \dots \mathbf{C}_{|h|}^{(l)}] \in \mathbb{R}^{|h| \times n_{l+1} \times n_l}\)。为了将\(h\)组结果聚合为一组结果,作者将三个维度分别看作深度、高度和宽度,然后使用一个\(1 \times 1\)的卷积进行聚合:

\[\mathbf{C}^{(l)}=\text{softmax}(\Gamma_{\phi}(\Vert_{k=0}^{|h|}\mathbf{C}_k^{(l)})) \in \mathbb{R}^{n_l \times n_{l+1}} \]

其中,\(\Gamma_{\phi}\)\(1 \times 1\)的卷积,\(\mathbf{C}^{(l)}\)就是聚合后的分配矩阵。

之后,值(value)矩阵\(\mathbf{V}^{(l)} \in \mathbb{R}^{n_{l+1} \times d_l}\)由下式定义:

\[\mathbf{V}^{(l)} = \mathbf{C}^{(l)T}\mathbf{Q}^{(l)} \in \mathbb{R}^{n_{l+1} \times d_l} \]

由于\(\mathbf{V}^{(l)}\)元素维度和\(\mathbf{Q}^{(l)}\)元素维度相同,作者认为这就表示在相同空间对节点聚类,之后还要经过一个单层前向网络将\(\mathbf{V}^{(l)}\)投影为新的查询:

\[\mathbf{Q}^{(l+1)} = \sigma(\mathbf{V}^{(l)}\mathbf{W}) \in \mathbb{R}^{n_{l+1} \times d_{l+1}} \]

其中\(\sigma\)是LeankyReLU激活函数。\(\mathbf{Q}^{(l+1)}\)将作为下一个记忆层的查询。

对于图分类任务,我们可以通过堆叠记忆层最终获得整个图的向量表示,然后用全连接层进行分类:

\[\mathcal{Y}=\text{softmax}(\text{MLP}(\mathcal{M}^{(l)}(\mathcal{M}^{(l-1)}(\dots \mathcal{M}^{(0)}(\mathbf{Q}^{(0)}))))) \]

其中,\(\mathbf{Q}^{(0)}=f_q(g)\)是将图\(g\)输入网络\(f_g\)得到的初始查询表示,也就是初始节点向量。根据\(f_q\)的不同,作者引出了两种模型,即GMN和MemGNN。

GMN架构

GMN将图中节点表示视为排列不变(permutation-invariant)集,也就是不考虑它们之间的空间关系,因此也不需要使用到图神经网络中的消息传递机制。但是,图中节点毕竟是存在拓扑关系的,完全不考虑是行不通的,因此作者考虑的是把节点的拓扑关系编码到节点的初始表示中。更具体地说,作者使用带重启的随机游走(RWR)(Pan et al., 2004)来计算拓扑嵌入,然后按行对它们进行排序,以强制节点嵌入保持顺序不变。得到包含拓扑信息的节点表示\(\mathbf{X} \in \mathbb{R}^{n \times d_{in}}\)后,初始的查询表示通过两层前向网络计算得到:

\[\begin{aligned} \mathbf{Q}^{(0)} &=f_q(g) \&=\sigma([\sigma(\mathbf{SW}_0) \Vert X]\mathbf{W}_1) \end{aligned} \]

其中\(\mathbf{W}_0 \in \mathbb{R}^{n\times d_{in}}\)\(\mathbf{W}_1 \in \mathbb{R}^{2d_{in}\times d_{0}}\)是参数,\(\mathbf{S} \in \mathbb{R}^{n\times n}\)是图扩散矩阵,\(\Vert\)表示拼接操作,\(\sigma\)是LeakyReLU激活函数。

MemGNN架构

MemGNN直接使用图神经网络计算初始查询:

\[\begin{aligned} \mathbf{Q}^{(0)} &=f_q(g) \&=G_{\theta}(\mathbf{A},\mathbf{X}) \end{aligned} \]

其中,\(G_{\theta}\)是任意的图神经网络。作者在实现时使用了GAT模型的改进版e-GAT,也就是在计算注意力权重时考虑了边特征。注意力权重计算公式为:

\[\alpha_{ij}=\frac{\exp(\sigma(\mathbf{W}[\mathbf{W}_n h_i^{(l)} \Vert \mathbf{W}_n h_j^{(l)} \Vert \mathbf{W}_e h_{i \rightarrow j}^{(l)}]))}{\sum_{k \in \mathcal{N}_i}\exp(\sigma(\mathbf{W}[\mathbf{W}_n h_i^{(l)} \Vert \mathbf{W}_n h_k^{(l)} \Vert \mathbf{W}_e h_{i \rightarrow k}^{(l)}]))} \]

其中\(h_i^{(l)}, h_{i \rightarrow j}^{(l)}\)分别是节点表示和边表示,\(\mathbf{W}_n, \mathbf{W}_e\)分别是节点权重和边权重,\(\mathbf{W}\)是前向网络参数,\(\sigma\)是LeakyReLU激活函数。

模型训练

模型的损失包含两部分,有监督损失和无监督损失。有监督损失\(\mathcal{L}_{sup}\)来自图分类或者图回归损失。无监督损失用于鼓励模型学习利于聚类的表示,由\(\mathbf{C}^{(l)}\)和辅助分布\(\mathbf{P}^{(l)}\)之间的KL散度定义:

\[\begin{aligned} \mathcal{L}_{KL}^{(l)} &= KL(\mathbf{P}^{(l)}||\mathbf{C}^{(l)}) \&=\sum_i \sum_j P_{ij}^{(l)} \log \frac{P_{ij}^{(l)}}{C_{ij}^{(l)}} \end{aligned} \]

其中辅助分布\(\mathbf{P}^{(l)}\)的计算和Xie et al., 2016一样,

\[P_{ij}^{(l)} = \frac{(C_{ij}^{(l)})^2 / \sum_i C_{ij}^{(l)}}{\sum_{j^{‘}}(C_{ij^{‘}}^{(l)})^2 / \sum_i C_{ij^{‘}}^{(l)}} \]

因此模型最终的损失定义为

\[\mathbf{L} = \frac{1}{N}\sum_{n=1}^N\left(\lambda \mathcal{L}_{sup} + (1-\lambda)\sum_{l=1}^L \mathcal{L}_{KL}^{(l)} \right) \]

为了使训练更稳定,\(\mathcal{L}_{sup}\)产生的的梯度每个batch进行反向传播,而\(\mathcal{L}_{KL}^{(l)}\)产生的梯度每个epoch反向传播一次,可以通过反复调整\(\lambda\)的取值为0或1实现。这是因为快速地调整聚类中心,也就是记忆键,可能会导致训练不稳定。

实验

论文主要关注图分类和图回归任务,使用了5个图分类数据集和2个图回归数据集:

技术图片

主要实验结果如下面几幅图所示:

技术图片

技术图片

技术图片

Memory-based Graph Networks

标签:矩阵   perm   com   img   计算公式   全连接   查询   鼓励   代码   

原文地址:https://www.cnblogs.com/weilonghu/p/12607387.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!