GCL Algorithm (1): introduction and implementation
创始人
2024-06-02 22:58:40
0

目录

1. Concept of GCL

1.1 自监督学习 

1.2 contrastive learning

1.3 data augmentation for GCL

1.3.1 Node dropping

1.3.2 edge perturbation 比扰动

1.3.3 attribute masking 属性掩码

1.3.4 subgraph 子图

2. GCL Algorithms

2.1 常见的图对比算法步骤

2.2 GraphCL Algorithm

2.2.1 Graph data augmentation

2.2.2 GNN-based encoder

2.2.3 Projection head

2.2.4 Contrastive loss function

3. Summary of GCL

4. GCL Implementation

3.1 semi-supervised 实现

3.2 unsupervised 实现

3.3 adversarial 实现

3.4 transfer learning 实现

参考

code

paper


1. Concept of GCL

图对比学习 GCL: GCL是一种针对图数据的自监督学习算法。

--》对给定的大量无标签数据,图对比算法旨在训练出一个图编辑器,即GNN,用以得到图表示向量。

1.1 自监督学习 

自监督学习:主要是利用辅助任务(pretext)从大规模的无监督数据中挖掘出自身的监督信息,通过这种构造的监督信息对网络进行训练,从而可以学习到对下游任务有价值的表征。

--> 通过各种方式从数据本身中为学习算法挖掘到了监督信息。

--> 是否存在不专注于具体细节的表示学习算法来对高层特征编码以实现不同对象之间的区分

1.2 contrastive learning

对比学习通过正面和负面的例子来学习表征。 

1.3 data augmentation for GCL

since data augmentations are the prerequisite for contrastive learning. 

没有数据增强的GCL效果还不如不用! 

1.3.1 Node dropping

随机从图中去掉部分比例的节点来扰动graph的完整性,每个节点的dropping概率服从均匀分布(即 random dropping)。 

1.3.2 edge perturbation 比扰动

随机增加或删除一定比例的边来扰动Graph的邻接关系,每个边的增加或删除概率服从均匀分布。 

1.3.3 attribute masking 属性掩码

随机masking部分节点的属性信息,迫使model使用上下文信息来重构masked node attributes

1.3.4 subgraph 子图

  • 使用随机游走的方式从Graph中提取子图的方法。
  • Graph数据也存在缺少标签或难以标注的问题。

--》将对比学习技术应用于图表示学习任务上。

Graph是一种离散的数据结构,且一些常见的图学习任务中,数据之间往往存在着紧密的关联(e.g. 链接预测)

2. GCL Algorithms

2.1 常见的图对比算法步骤

1) random sampling 一批(batch) graph

2) 对每一个图进行两次随机的data augmentation,增强后的两个新图称为view。

3) 使用带训练的GNN对view进行编码,得到节点表示向量(node representation)和图表示向量(graph representation)。

4) 根据上述表示向量计算InforNCE损失,其中由同一个graph增强出来的view表示相互靠近,由不同graph增强出来的view表示相互远离。

2.2 GraphCL Algorithm

GraphCL for self-supervised pre-training of GNNs. In graph contrastive learning, pre-training is performed through maximizing the agreement between two augmented views of the same graph via a contrastive loss in the latent space.

paper: Graph Contrastive Learning with Augmentations, NeurIPS 2020.

2.2.1 Graph data augmentation

The given graph G undergoes graph data augmentations to obtain two correlated views Gi, Gj, as a positive pair. 

2.2.2 GNN-based encoder

A GNN-based encoder f() extracts graph-level representation vectors hi, hj for augmented graphs Gi, Gj. Graph contrastive learning does not apply any constraint on the GNN architecture.

2.2.3 Projection head

A non-linear transformation g()(激活函数) named projection head maps augmented representations to another latent space space where the contrastive loss is calculated, e.g. MLP, to obtain zi, zj.

2.2.4 Contrastive loss function

A contrastive loss function L() is defined to enforce maximizing the consistency between positive pairs zi, zj compared with negative pairs.

3. Summary of GCL

  • 数据增强对GCL至关重要。without any data augmentation graph contrastive learning is not helpful and often worse.
  • composing different augmentations benefits more.
  • edge perturbation benefits social networks but hurts some biochemical molecules.
  • applying attribute masking achieves better performance in denser graphs.
  • Node dropping and subgraph are generally beneficial across datasets. For subgraph, previous works show that enforcing local (the subgraphs we extract) and global information consistency is helpful for representation learning.

4. GCL Implementation

3.1 semi-supervised 实现

3.2 unsupervised 实现

参考: GraphCL/unsupervised_Cora_Citeseer at master · Shen-Lab/GraphCL · GitHub

1) 构造augmented feature1 and feature2;augmented adjacency matrix1 and matrix2

2) 构建自监督supervised information,由torch.ones全1矩阵和torch.zeros全0矩阵

3)在给定augmented feature and adjacency matrix前提下,由discriminator 1区分第一种数据增强下的feature和shuffled feature,由discriminator 2区分第二种数据增强下的feature和shuffled feature,将结果ret1和ret2相加,作为model学习结果。

4) 将model预测结果与自监督矩阵做反向传播和梯度下降,学习出最优模型参数,以后后面生成feature embeddings。

for epoch in range(nb_epochs):model.train()optimiser.zero_grad()idx = np.random.permutation(nb_nodes)shuf_fts = features[:, idx, :]lbl_1 = torch.ones(batch_size, nb_nodes)  # labelslbl_2 = torch.zeros(batch_size, nb_nodes)lbl = torch.cat((lbl_1, lbl_2), 1)if torch.cuda.is_available():shuf_fts = shuf_fts.cuda()lbl = lbl.cuda()logits = model(features, shuf_fts, aug_features1, aug_features2,sp_adj if sparse else adj,sp_aug_adj1 if sparse else aug_adj1,sp_aug_adj2 if sparse else aug_adj2,sparse, None, None, None, aug_type=aug_type)loss = b_xent(logits, lbl)  # 在augmentation前提下,discriminater学习区分features和shuffle_features.print('Loss:[{:.4f}]'.format(loss.item()))if loss < best:best = lossbest_t = epochcnt_wait = 0torch.save(model.state_dict(), args.save_name)else:cnt_wait += 1if cnt_wait == patience:print('Early stopping!')breakloss.backward()optimiser.step()

3.3 adversarial 实现

3.4 transfer learning 实现

参考

code

  • GitHub - Shen-Lab/GraphCL: [NeurIPS 2020] "Graph Contrastive Learning with Augmentations" by Yuning You, Tianlong Chen, Yongduo Sui, Ting Chen, Zhangyang Wang, Yang Shen

paper

[1] Graph Contrastive Learning with Augmentations, NeurIPS 2020.

相关内容

热门资讯

新还珠格格,欣荣和永琪有个孩子... 新还珠格格,欣荣和永琪有个孩子?不是说永琪从来都没碰过她吗?绵忆到底是他和小燕子的还是欣荣的啊求正解...
中级会计怎么备考?今年几月考试... 中级会计怎么备考?今年几月考试?您好,很高兴为您解答中级会计师考试,教材是根本和基础,所有的题目都是...
继兴业、招商、中信后,邮储银行... (来源:现代商业银行杂志)金融资产投资公司(AIC)队伍再添新员。邮储银行近日发布公告称,该行拟以自...
中央巡视组对陕西开展两个半月常... 转自:北京日报客户端日前,中央第十五巡视组进驻陕西省,将开展为期两个半月左右的常规巡视,并会同陕西省...
柳州幻境空间在哪里 柳州幻境空间在哪里柳州幻境空间是位于广西柳州市城中区华联商闷郑城4楼的室内主题乐园,提供了各种游戏和...
中央巡视组进驻山东 联动巡视济... 转自:央视新闻客户端经党中央批准,二十届中央第六轮巡视将对16个省(自治区、直辖市)开展常规巡视,并...
继续发布暴雨蓝色预警!北京等地... 转自:央视新闻客户端中央气象台19日早6时继续发布暴雨蓝色预警。预计,19日早8时至20日早8时,青...
降妖伏魔篇演员有哪些 降妖伏魔篇演员有哪些文章舒淇程小东黄勃
晚上十一点在河边抓鱼听到有人叫... 晚上十一点在河边抓鱼听到有人叫我小名声音跟我一个朋友一样,电筒照却没有发现有人而且我女朋友也听见了不...
属猴的为什么吸引属狗的人 属猴的为什么吸引属狗的人属相狗虽不善甜言蜜语,为人多有情感之被捉,然其铅轮内心却多有向往甜蜜幸福之生...