汀丶

V1

2022/12/02阅读:20主题:默认主题

12.PGL图学习之项目实践(UniMP算法实现论文节点分类、新冠疫苗项目实战,助力疫情)[系列九]

原项目链接:https://aistudio.baidu.com/aistudio/projectdetail/5100049?contributionType=1

1.图学习技术与应用

图是一个复杂世界的通用语言,社交网络中人与人之间的连接、蛋白质分子、推荐系统中用户与物品之间的连接等等,都可以使用图来表达。图神经网络将神经网络运用至图结构中,可以被描述成消息传递的范式。百度开发了PGL2.2,基于底层深度学习框架paddle,给用户暴露了编程接口来实现图网络。与此同时,百度也使用了前沿的图神经网络技术针对一些应用进行模型算法的落地。本次将介绍百度的PGL图学习技术与应用。

1.1图来源与建模

首先和大家分享下图学习主流的图神经网络建模方式。

14年左右开始,学术界出现了一些基于图谱分解的技术,通过频域变换,将图变换至频域进行处理,再将处理结果变换回空域来得到图上节点的表示。后来,空域卷积借鉴了图像的二维卷积,并逐渐取代了频域图学习方法。图结构上的卷积是对节点邻居的聚合。

基于空间的图神经网络主要需要考虑两个问题:

  • 怎样表达节点特征;

  • 怎样表达一整张图。

第一个问题可以使用邻居聚合的方法,第二问题使用节点聚合来解决。

目前大部分主流的图神经网络都可以描述成消息传递的形式。需要考虑节点如何将消息发送至目标节点,然后目标节点如何对收到的节点特征进行接收。

1.2 PGL2.2回顾介绍

PGL2.2基于消息传递的思路构建整体框架。PGL最底层是飞浆核心paddle深度学习框架。在此之上,搭建了CPU图引擎和GPU上进行tensor化的图引擎,来方便对图进行如图切分、图存储、图采样、图游走的算法。再上一层,会对用户暴露一些编程接口,包括底层的消息传递接口和图网络实现接口,以及高层的同构图、异构图的编程接口。框架顶层会支持几大类图模型,包括传统图表示学习中的图游走模型、消息传递类模型、知识嵌入类模型等,去支撑下游的应用场景。

最初的PGL是基于paddle1.x的版本进行开发的,所以那时候还是像tensorflow一样的静态图模式。目前paddle2.0已经进行了全面动态化,那么PGL也相应地做了动态图的升级。现在去定义一个图神经网络就只需要定义节点数量、边数量以及节点特征,然后将图tensor化即可。可以自定义如何将消息进行发送以及目标节点如何接收消息。

上图是使用PGL构建一个GAT网络的例子。最开始会去计算节点的权重,在发送消息的时候GAT会将原节点和目标节点特征进行求和,再加上一个非线性激活函数。在接收的时候,可以通过reduce_softmax对边上的权重进行归一化,再乘上hidden state进行加权求和。这样就可以很方便地实现一个GAT网络。

对于图神经网络来讲,在构建完网络后,要对它进行训练。训练方式和一般机器学习有所不同,需要根据图的规模选择适用的训练方案。

例如在小图,即图规模小于GPU显存的情况下,会使用full batch模式进行训练。它其实就是把一整张图的所有节点都放置在GPU上,通过一个图网络来输出所有点的特征。它的好处在于可以跑一个很深的图。这一训练方案会被应用于中小型数据集,例如Cora、Pubmed、Citeseer、ogbn-arxiv等。最近在ICML上发现了可以堆叠至1000层的图神经网络,同样也是在这种中小型数据集上做评估。

对于中等规模的图,即图规模大于GPU单卡显存,知识可以进行分片训练,每一次将一张子图塞入GPU上。PGL提供了另一个方案,使用分片技术来降低显存使用的峰值。例如对一个复杂图进行计算时,它的计算复杂度取决于边计算时显存使用的峰值,此时如果有多块GPU就可以把边计算进行分块,每台机器只负责一小部分的计算,这样就可以大大地减少图神经网络的计算峰值,从而达到更深的图神经网络的训练。分块训练完毕后,需要通过NCCL来同步节点特征。

在PGL中,只需要一行DistGPUGraph命令就可以在原来full batch的训练代码中加入这样一个新特性,使得可以在多GPU中运行一个深层图神经网络。例如在obgn-arxiv中尝试了比较复杂的TransformerConv网络,如果使用单卡训练一个三层网络,其GPU显存会被占用近30G,而使用分片训练就可以将它的显存峰值降低。同时,还实现了并行的计算加速,例如原来跑100 epoch需要十分钟,现在只需要200秒。

在大图的情况下,又回归到平时做数据并行的mini batch模式。Mini batch与full batch相比最主要的问题在于它需要做邻居的采样,而邻居数目的提升会对模型的深度进行限制。这一模式适用于一些巨型数据集,包括ogbn-products和ogbn-papers100m。

发现PyG的作者的新工作GNNAutoScale能够把一个图神经网络进行自动的深度扩展。它的主要思路是利用CPU的缓存技术,将邻居节点的特征缓存至CPU内存中。当训练图网络时,可以不用实时获取所有邻居的最新表达,而是获取它的历史embedding进行邻居聚合计算。实验发现这样做的效果还是不错的。

在工业界的情况下可能会存在更大的图规模的场景,那么这时候可能单CPU也存不下如此图规模的数据,这时需要一个分布式的多机存储和采样。PGL有一套分布式的图引擎接口,使得可以轻松地在MPI以及K8S集群上通过PGL launch接口进行一键的分布式图引擎部署。目前也支持不同类型的邻居采样、节点遍历和图游走算法。

整体的大规模训练方式包括一个大规模分布式图引擎,中间会包含一些图采样的算子和神经网络的开发算子。顶层针对工业界大规模场景,往往需要一个parameter server来存储上亿级别的稀疏特征。借助paddlefleet的大规模参数服务器来支持超大规模的embedding存储。

1.3 图神经网络技术

1.3.1 节点分类任务

在算法上也进行了一些研究。图神经网络与一般机器学习场景有很大的区别。一般的机器学习假设数据之间独立同分布,但是在图网络的场景下,样本是有关联的。预测样本和训练样本有时会存在边关系。通常称这样的任务为半监督节点分类问题。

解决节点分类问题的传统方法是LPA标签传播算法,考虑链接关系以及标签之间的关系。另外一类方法是以GCN为代表的特征传播算法,只考虑特征与链接的关系。

通过实验发现在很多数据集下,训练集很难通过过拟合达到99%的分类准确率。也就是说,训练集中的特征其实包含很大的噪声,使得网络缺乏过拟合能力。所以,想要显示地将训练label加入模型,因为标签可以消减大部分歧义。在训练过程中,为了避免标签泄露,提出了UniMP算法,把标签传播和特征传播融合起来。这一方法在三个open graph benchmark数据集上取得了SOTA的结果。

后续还把UniMP应用到更大规模的KDDCup 21的比赛中,将UniMP同构算法做了异构图的拓展,使其在异构图场景下进行分类任务。具体地,在节点邻居采样、批归一化和注意力机制中考虑节点之间的关系类型。

1.3.2 链接预测任务

第二个比较经典的任务是链接预测任务。目前很多人尝试使用GNN与link prediction进行融合,但是这存在两个瓶颈。首先,GNN的深度和邻居采样的数量有关;其次,当训练像知识图谱的任务时,每一轮训练都需要遍历训练集的三元组,此时训练的复杂度和邻居节点数量存在线性关系,这就导致了如果邻居比较多,训练一个epoch的耗时很长。

借鉴了最近基于纯特征传播的算法,如SGC等图神经网络的简化方式,提出了基于关系的embedding传播。发现单独使用embedding进行特征传播在知识图谱上是行不通的。因为知识图谱上存在复杂的边关系。所以,根据不同关系下embedding设计了不同的score function进行特征传播。此外,发现之前有一篇论文提出了OTE的算法,在图神经网络上进行了两阶段的训练。

使用OGBL-WikiKG2数据集训练OTE模型需要超过100个小时,而如果切换到的特征传播算法,即先跑一次OTE算法,再进行REP特征传播,只需要1.7个小时就可以使模型收敛。所以REP带来了近50倍的训练效率的提升。还发现只需要正确设定score function,大部分知识图谱算法使用的特征传播算法都会有效果上的提升;不同的算法使用REP也可以加速它们的收敛。

将这一套方法应用到KDDCup 21 Wiki90M的比赛中。为了实现比赛中要求的超大规模知识图谱的表示,做了一套大规模的知识表示工具Graph4KG,最终在KDDCup中取得了冠军。

1.4 算法应用落地

PGL在百度内部已经进行了广泛应用。包括百度搜索中的网页质量评估,会把网页构成一个动态图,并在图上进行图分类的任务。百度搜索还使用PGL进行网页反作弊,即对大规模节点进行检测。在文本检索应用中,尝试使用图神经网络与自然语言处理中的语言模型相结合。在其他情况下,的落地场景有推荐系统、风控、百度地图中的流量预测、POI检索等。

本文以推荐系统为例,介绍一下平时如何将图神经网络在应用中进行落地。

推荐系统常用的算法是基于item-based和user-based协同过滤算法。Item-based协同过滤就是推荐和item相似的内容,而user-based 就是推荐相似的用户。这里最重要的是如何去衡量物品与物品之间、用户与用户之间的相似性。

可以将其与图学习结合,使用点击日志来构造图关系(包括社交关系、用户行为、物品关联),然后通过表示学习构造用户物品的向量空间。在这个空间上就可以度量物品之间的相似性,以及用户之间的相似性,进而使用其进行推荐。

常用的方法有传统的矩阵分解方法,和阿里提出的基于随机游走 + Word2Vec的EGES算法。近几年兴起了使用图对比学习来获得节点表示。

在推荐算法中,主要的需求是支持复杂的结构,支持大规模的实现和快速的实验成本。希望有一个工具包可以解决GNN + 表示学习的问题。所以,对现有的图表示学习算法进行了抽象。具体地,将图表示学习分成了四个部分。第一部分是图的类型,将其分为同构图、异构图、二部图,并在图中定义了多种关系,例如点击关系、关注关系等。第二,实现了不同的样本采样的方法,包括在同构图中常用的node2Vec以及异构图中按照用户自定义的meta path进行采样。第三部分是节点的表示。可以根据id去表示节点,也可以通过图采样使用子图来表示一个节点。还构造了四种GNN的聚合方式。

发现不同场景以及不同的图表示的训练方式下,模型效果差异较大。所以的工具还支持大规模稀疏特征side-info的支持来进行更丰富的特征组合。用户可能有很多不同的字段,有些字段可能是缺失的,此时只需要通过一个配置表来配置节点包含的特征以及字段即可。还支持GNN的异构图自动扩展。你可以自定义边关系,如点击关系、购买关系、关注关系等,并选取合适的聚合方式,如lightgcn,就可以自动的对GNN进行异构图扩展,使lightgcn变为relation-wise的lightgcn。

对工具进行了瓶颈分析,发现它主要集中在分布式训练中图采样和负样本构造中。可以通过使用In-Batch Negative的方法进行优化,即在batch内走负采样,减少通讯开销。这一优化可以使得训练速度提升四至五倍,而且在训练效果上几乎是无损的。此外,在图采样中可以通过对样本重构来降低采样的次数,得到两倍左右的速度提升,且训练效果基本持平。相比于市面上现有的分布式图表示工具,还可以实现单机、双机、四机甚至更多机器的扩展。

不仅如此,还发现游走类模型训练速度较快,比较适合作为优秀的热启动参数。具体地,可以先运行一次metapath2Vce算法,将训练得到的embedding作为初始化参数送入GNN中作为热启动的节点表示。发现这样做在效果上有一定的提升。

1.5 Q&A

Q1:在特征在多卡之间传递的训练模式中,使用push和pull的方式通讯时间占比大概有多大?

A:通讯时间的占比挺大的。如果是特别简单的模型,如GCN等,那么使用这种方法训练,通讯时间甚至会比直接跑这个模型的训练时间还要久。所以这一方法适合复杂模型,即模型计算较多,且通讯中特征传递的数据量相比来说较小,这种情况下就比较适合这种分布式计算。

Q2:图学习中节点邻居数较多会不会导致特征过平滑?

A:这里采用的方法很多时候都很暴力,即直接使用attention加多头的机制,这样会极大地减缓过平滑问题。因为使用attention机制会使得少量特征被softmax激活;多头的方式可以使得每个头学到的激活特征不一样。所以这样做一定比直接使用GCN进行聚合会好。

Q3:百度有没有使用图学习在自然语言处理领域的成功经验?

A:之前有类似的工作,你可以关注ERINESage这篇论文。它主要是将图网络和预训练语言模型进行结合。也将图神经网络落地到了例如搜索、推荐的场景。因为语言模型本身很难对用户日志中包含的点击关系进行建模,通过图神经网络就可以将点击日志中的后验关系融入语言模型,进而得到较大的提升。

Q4:能详细介绍一下KDD比赛中将同构图拓展至异构图的UniMP方法吗?

A:首先,每一个关系类型其实应该有不同的邻居采样方法。例如paper到author的关系,会单独地根据它来采样邻居节点。如果按照同构图的方式来采样,目标节点的邻居节点可能是论文,也可能是作者或者机构,那么采样的节点是不均匀的。其次,在批归一化中按照关系channel来进行归一化,因为如果你将paper节点和author节点同时归一化,由于它们的统计均值和方差不一样,那么这种做法会把两者的统计量同时带骗。同理,在聚合操作中,不同的关系对两个节点的作用不同,需要按照不同关系使用不同的attention注意力权重来聚合特征。

2.基于UniMP算法实现论文引用网络节点分类任务

图学习之基于PGL-UniMP算法的论文引用网络节点分类任务:https://aistudio.baidu.com/aistudio/projectdetail/5116458?contributionType=1

由于文章篇幅问题,为了让学习者有更好的体验,这里新开一个项目完成这个任务。

Epoch 987 Train Acc 0.7554459 Valid Acc 0.7546095
Epoch 988 Train Acc 0.7537374 Valid Acc 0.75717235
Epoch 989 Train Acc 0.75497127 Valid Acc 0.7573859
Epoch 990 Train Acc 0.7611409 Valid Acc 0.75653166
Epoch 991 Train Acc 0.75316787 Valid Acc 0.75489426
Epoch 992 Train Acc 0.749561 Valid Acc 0.7547519
Epoch 993 Train Acc 0.7571544 Valid Acc 0.7551079
Epoch 994 Train Acc 0.7516492 Valid Acc 0.75581974
Epoch 995 Train Acc 0.7563476 Valid Acc 0.7563181
Epoch 996 Train Acc 0.7504627 Valid Acc 0.7538976
Epoch 997 Train Acc 0.7476152 Valid Acc 0.75439596
Epoch 998 Train Acc 0.7539272 Valid Acc 0.7528298
Epoch 999 Train Acc 0.7532153 Valid Acc 0.75396883

3.新冠疫苗项目实战,助力疫情

Kaggle新冠疫苗研发竞赛:https://www.kaggle.com/c/stanford-covid-vaccine/overview

mRNA疫苗已经成为2019冠状病毒最快的候选疫苗,但目前它们面临着关键的潜在限制。目前最大的挑战之一是如何设计超稳定的RNA分子(mRNA)。传统疫苗是装在注射器里通过冷藏运输到世界各地,但mRNA疫苗目前还不可能做到这一点。

研究人员已经观察到RNA分子有降解的倾向。这是一个严重的限制,降解会使mRNA疫苗失效。目前,对于特定RNA的主干中哪个部位最容易受影响的细节知之甚少。在不了解这些情况的情况下,目前针对COVID-19的mRNA疫苗必须在高度冷藏条件下准备和运输,它们必须能够得到稳定,否则不太可能送达地球上的每个人。

由斯坦福大学医学院(Stanford’s School of Medicine)计算生物学家瑞朱·达斯(Rhiju Das)教授领导的永恒星系(Eterna)社区将科学家和竞赛玩家聚集在一起,解决谜题并发明药物。Eterna是一款在线竞赛平台,通过谜题挑战玩家解决诸如mRNA设计等科学问题。由斯坦福大学的研究人员合成并进行实验测试,以获得关于RNA分子的新见解。Eterna社区之前已经开启了新的科学原理,对致命疾病做出了新的诊断,并利用世界上最强大的智力资源改善公众生活。Eterna社区通过其在20多份出版物上的贡献推动了生物技术,包括RNA生物技术进展。

在这次竞赛中,我们希望利用Kaggle社区的数据科学专业知识来开发模型和设计RNA降解规则。模型将预测RNA分子每个碱基的可能降解率,训练的对象是由超过3000个RNA分子组成的Eterna数据集子集(它们跨越了一整套序列和结构),以及它们在每个位置的降解率。然后,我们将根据Eterna玩家刚刚为COVID-19 mRNA疫苗设计的第二代RNA序列为模型评分。这些最终的测试序列目前正在合成和实验表征在斯坦福大学与建模工作并行——自然将评分模型!

提高mRNA疫苗的稳定性已经在探索,我们必须解决这一深刻的科学挑战,以加速mRNA疫苗研究,并提供一种针对COVID-19背后病毒SARS-CoV-2的冰箱稳定疫苗。我们正在试图解决的问题希望得到学术实验室、工业研发团队和超级计算机的帮助,你可以加入电子竞赛玩家、科学家和开发者的团队,在Eterna永恒星球上对抗这一毁灭性病毒。

3.1案例简介

将编码的DNA送到细胞中,细胞使用mRNA(Messenger RNA)组装蛋白,免疫系统检测到组装蛋白质以后,利用构建病毒蛋白的编码基因激活免疫系统产生抗体,增强针对冠状病毒的抵御能力。

不同的mRNA生成同一个蛋白质,

mRNA随着时间的流逝及温度的变化发生了降解,

如何找到结构更加稳定的mRNA?利用图神经网络找到更稳定的mRNA,颜色越深越稳定.

3.2 新冠疫苗项目拔高实战

数据分布特征

查看当前挂载的数据集目录

# 加载一些需要用到的模块,设置随机数
import json
import random
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import networkx as nx

from utils.config import prepare_config, make_dir
from utils.logger import prepare_logger, log_to_file
from data_parser import GraphParser

seed = 123
np.random.seed(seed)
random.seed(seed)

# https://www.kaggle.com/c/stanford-covid-vaccine/data
# 加载训练用的数据
df = pd.read_json('../data/data179441/train.json', lines=True)
# 查看一下数据集的内容
sample = df.loc[0]
print(sample)

index                                                                400
id                                                          id_2a7a4496f
sequence               GGAAAGCCCGCGGCGCCGGGCGCCGCGGCCGCCCAGGCCGCCCGGC...
structure              .....(((...)))((((((((((((((((((((.((((....)))...
predicted_loop_type    EEEEESSSHHHSSSSSSSSSSSSSSSSSSSSSSSISSSSHHHHSSS...
signal_to_noise                                                        0
SN_filter                                                              0
seq_length                                                           107
seq_scored                                                            68
reactivity_error       [146151.225, 146151.225, 146151.225, 146151.22...
deg_error_Mg_pH10      [104235.1742, 104235.1742, 104235.1742, 104235...
deg_error_pH10         [222620.9531, 222620.9531, 222620.9531, 222620...
deg_error_Mg_50C       [171525.3217, 171525.3217, 171525.3217, 171525...
deg_error_50C          [191738.0886, 191738.0886, 191738.0886, 191738...
reactivity             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_Mg_pH10            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_pH10               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_Mg_50C             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_50C                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
Name: 0, dtype: object

例如 deg_50C、deg_Mg_50C 这样的值全为0的行,就是我们需要预测的。

structure一行,数据中的括号是为了构成边用的。

本案例要预测RNA序列不同位置的降解速率,训练数据中提供了多个ground值,标签包括以下几项:reactivity, deg_Mg_pH10, and deg_Mg_50

  • reactivity - (1x68 vector 训练集,1x91测试集) 一个浮点数数组,与seq_scores有相同的长度,是前68个碱基的反应活性值,按顺序表示,用于确定RNA样本可能的二级结构。

  • deg_Mg_pH10 - (训练集 1x68向量,1x91测试集)一个浮点数数组,与seq_scores有相同的长度,是前68个碱基的反应活性值,按顺序表示,用于确定在高pH (pH 10)下的降解可能性。

  • deg_Mg_50 - (训练集 1x68向量,1x91测试集)一个浮点数数组,与seq_scores有相同的长度,是前68个碱基的反应活性值,按顺序表示,用于确定在高温(50摄氏度)下的降解可能性。

# 利用GraphParser构造图结构的数据
args = prepare_config("./config.yaml", isCreate=False, isSave=False)
parser = GraphParser(args) # GraphParser类来自data_parser.py
gdata = parser.parse(sample) # GraphParser里最主要的函数就是parse(self, sample)

数据格式:

{'nfeat': array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 0., 0.],
        ...,
        [1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 'edges': array([[  0,   1],
        [  1,   0],
        [  1,   2],
        ...,
        [142, 105],
        [106, 142],
        [142, 106]]),
 'efeat': array([[ 0.,  0.,  0.,  1.,  1.],
        [ 0.,  0.,  0., -1.,  1.],
        [ 0.,  0.,  0.,  1.,  1.],
        ...,
        [ 0.,  1.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0.]], dtype=float32),
 'labels': array([[ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ],
        ...,
        [ 0.    ,  0.9213,  0.    ],
        [ 6.8894,  3.5097,  5.7754],
        [ 0.    ,  1.8426,  6.0642],
          ...,        
        [ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ]], dtype=float32),
 'mask': array([[ True],
        [ True],
     ......
       [False]])}
# 图数据可视化
fig = plt.figure(figsize=(24, 12))
nx_G = nx.Graph()
nx_G.add_nodes_from([i for i in range(len(gdata['nfeat']))])

nx_G.add_edges_from(gdata['edges'])
node_color = ['g' for _ in range(sample['seq_length'])] + \
['y' for _ in range(len(gdata['nfeat']) - sample['seq_length'])]
options = {
    "node_color": node_color,
}
pos = nx.spring_layout(nx_G, iterations=400, k=0.2)
nx.draw(nx_G, pos, **options)

plt.show()

Snipaste_2022-11-25_20-53-23.jpg
Snipaste_2022-11-25_20-53-23.jpg
从图中可以看到,绿色节点是碱基,黄色节点是密码子。



结果返回的是 MCRMSE 和 loss

{'MCRMSE': 0.5496759, 'loss': 0.3025484172316889}

[DEBUG] 2022-11-25 17:50:42,468 [ trainer.py: 66]: {'MCRMSE': 0.5496759, 'loss': 0.3025484172316889} [DEBUG] 2022-11-25 17:50:42,468 [ trainer.py: 73]: write to tensorboard ../checkpoints/covid19/eval_history/eval [DEBUG] 2022-11-25 17:50:42,469 [ trainer.py: 73]: write to tensorboard ../checkpoints/covid19/eval_history/eval [INFO] 2022-11-25 17:50:42,469 [ trainer.py: 76]: [Eval:eval]:MCRMSE:0.5496758818626404 loss:0.3025484172316889 [INFO] 2022-11-25 17:50:42,602 [monitored_executor.py: 606]: ********** Stop Loop ************ [DEBUG] 2022-11-25 17:50:42,607 [monitored_executor.py: 199]: saving step 12500 to ../checkpoints/covid19/model_12500


这部分代码实现参考项目:[PGL图学习之基于GNN模型新冠疫苗任务[系列九]](https://aistudio.baidu.com/aistudio/projectdetail/5123296?contributionType=1)

# 我们在 layer.py 里定义了一个新的 gnn 模型(my_gnn),消息传递的过程中加入了边的特征(edge_feat)
# 然后修改 model.py 里的 GNNModel
# 使用修改后的模型,运行 main.py。为节省时间,设置 epochs = 100

# !python main.py --config config.yaml #训练
#!python main.py --mode infer #预测

4.总结

本项目讲了论文节点分类任务和新冠疫苗任务,并在论文节点分类任务中对代码进行详细讲解。PGL八九系列的项目耦合性比较大,也花了挺久时间研究希望对大家有帮助。

后续将做一次大的总结偏向业务侧该如何落地以及图算法的归纳,之后会进行不定期更新图相关的算法!

  • easydict库和collections库!
  • 从官方数据处理部分,学习到利用np的vstack实现自环边以及知道有向边如何添加反向边的数据——这样的一种代码实现边数据转换的方式!
  • 从模型加载部分,学习了多program执行的操作,理清了program与命名空间之间的联系!
  • 从模型训练部分,强化了执行器执行时,需要传入正确的program以及feed_dict,在pgl中可以使用图Graph自带的to_feed方法返回一个feed_dict数据字典作为初始数据,后边再按需添加新数据!
  • 从model.py学习了模型的组网,以及pgl中conv类下的网络模型方法的调用,方便组网!
  • 重点来了:从build_model.py学习了模型的参数的加载组合,实现统一的处理和返回统一的算子以及参数!

分类:

人工智能

标签:

AI算法

作者介绍

汀丶
V1

将不定期更新关于NLP等领域相关知识,