GCL Algorithm (1): introduction and implementation
创始人
2024-06-02 22:58:40
0

目录

1. Concept of GCL

1.1 自监督学习 

1.2 contrastive learning

1.3 data augmentation for GCL

1.3.1 Node dropping

1.3.2 edge perturbation 比扰动

1.3.3 attribute masking 属性掩码

1.3.4 subgraph 子图

2. GCL Algorithms

2.1 常见的图对比算法步骤

2.2 GraphCL Algorithm

2.2.1 Graph data augmentation

2.2.2 GNN-based encoder

2.2.3 Projection head

2.2.4 Contrastive loss function

3. Summary of GCL

4. GCL Implementation

3.1 semi-supervised 实现

3.2 unsupervised 实现

3.3 adversarial 实现

3.4 transfer learning 实现

参考

code

paper


1. Concept of GCL

图对比学习 GCL: GCL是一种针对图数据的自监督学习算法。

--》对给定的大量无标签数据,图对比算法旨在训练出一个图编辑器,即GNN,用以得到图表示向量。

1.1 自监督学习 

自监督学习:主要是利用辅助任务(pretext)从大规模的无监督数据中挖掘出自身的监督信息,通过这种构造的监督信息对网络进行训练,从而可以学习到对下游任务有价值的表征。

--> 通过各种方式从数据本身中为学习算法挖掘到了监督信息。

--> 是否存在不专注于具体细节的表示学习算法来对高层特征编码以实现不同对象之间的区分

1.2 contrastive learning

对比学习通过正面和负面的例子来学习表征。 

1.3 data augmentation for GCL

since data augmentations are the prerequisite for contrastive learning. 

没有数据增强的GCL效果还不如不用! 

1.3.1 Node dropping

随机从图中去掉部分比例的节点来扰动graph的完整性,每个节点的dropping概率服从均匀分布(即 random dropping)。 

1.3.2 edge perturbation 比扰动

随机增加或删除一定比例的边来扰动Graph的邻接关系,每个边的增加或删除概率服从均匀分布。 

1.3.3 attribute masking 属性掩码

随机masking部分节点的属性信息,迫使model使用上下文信息来重构masked node attributes

1.3.4 subgraph 子图

  • 使用随机游走的方式从Graph中提取子图的方法。
  • Graph数据也存在缺少标签或难以标注的问题。

--》将对比学习技术应用于图表示学习任务上。

Graph是一种离散的数据结构,且一些常见的图学习任务中,数据之间往往存在着紧密的关联(e.g. 链接预测)

2. GCL Algorithms

2.1 常见的图对比算法步骤

1) random sampling 一批(batch) graph

2) 对每一个图进行两次随机的data augmentation,增强后的两个新图称为view。

3) 使用带训练的GNN对view进行编码,得到节点表示向量(node representation)和图表示向量(graph representation)。

4) 根据上述表示向量计算InforNCE损失,其中由同一个graph增强出来的view表示相互靠近,由不同graph增强出来的view表示相互远离。

2.2 GraphCL Algorithm

GraphCL for self-supervised pre-training of GNNs. In graph contrastive learning, pre-training is performed through maximizing the agreement between two augmented views of the same graph via a contrastive loss in the latent space.

paper: Graph Contrastive Learning with Augmentations, NeurIPS 2020.

2.2.1 Graph data augmentation

The given graph G undergoes graph data augmentations to obtain two correlated views Gi, Gj, as a positive pair. 

2.2.2 GNN-based encoder

A GNN-based encoder f() extracts graph-level representation vectors hi, hj for augmented graphs Gi, Gj. Graph contrastive learning does not apply any constraint on the GNN architecture.

2.2.3 Projection head

A non-linear transformation g()(激活函数) named projection head maps augmented representations to another latent space space where the contrastive loss is calculated, e.g. MLP, to obtain zi, zj.

2.2.4 Contrastive loss function

A contrastive loss function L() is defined to enforce maximizing the consistency between positive pairs zi, zj compared with negative pairs.

3. Summary of GCL

  • 数据增强对GCL至关重要。without any data augmentation graph contrastive learning is not helpful and often worse.
  • composing different augmentations benefits more.
  • edge perturbation benefits social networks but hurts some biochemical molecules.
  • applying attribute masking achieves better performance in denser graphs.
  • Node dropping and subgraph are generally beneficial across datasets. For subgraph, previous works show that enforcing local (the subgraphs we extract) and global information consistency is helpful for representation learning.

4. GCL Implementation

3.1 semi-supervised 实现

3.2 unsupervised 实现

参考: GraphCL/unsupervised_Cora_Citeseer at master · Shen-Lab/GraphCL · GitHub

1) 构造augmented feature1 and feature2;augmented adjacency matrix1 and matrix2

2) 构建自监督supervised information,由torch.ones全1矩阵和torch.zeros全0矩阵

3)在给定augmented feature and adjacency matrix前提下,由discriminator 1区分第一种数据增强下的feature和shuffled feature,由discriminator 2区分第二种数据增强下的feature和shuffled feature,将结果ret1和ret2相加,作为model学习结果。

4) 将model预测结果与自监督矩阵做反向传播和梯度下降,学习出最优模型参数,以后后面生成feature embeddings。

for epoch in range(nb_epochs):model.train()optimiser.zero_grad()idx = np.random.permutation(nb_nodes)shuf_fts = features[:, idx, :]lbl_1 = torch.ones(batch_size, nb_nodes)  # labelslbl_2 = torch.zeros(batch_size, nb_nodes)lbl = torch.cat((lbl_1, lbl_2), 1)if torch.cuda.is_available():shuf_fts = shuf_fts.cuda()lbl = lbl.cuda()logits = model(features, shuf_fts, aug_features1, aug_features2,sp_adj if sparse else adj,sp_aug_adj1 if sparse else aug_adj1,sp_aug_adj2 if sparse else aug_adj2,sparse, None, None, None, aug_type=aug_type)loss = b_xent(logits, lbl)  # 在augmentation前提下,discriminater学习区分features和shuffle_features.print('Loss:[{:.4f}]'.format(loss.item()))if loss < best:best = lossbest_t = epochcnt_wait = 0torch.save(model.state_dict(), args.save_name)else:cnt_wait += 1if cnt_wait == patience:print('Early stopping!')breakloss.backward()optimiser.step()

3.3 adversarial 实现

3.4 transfer learning 实现

参考

code

  • GitHub - Shen-Lab/GraphCL: [NeurIPS 2020] "Graph Contrastive Learning with Augmentations" by Yuning You, Tianlong Chen, Yongduo Sui, Ting Chen, Zhangyang Wang, Yang Shen

paper

[1] Graph Contrastive Learning with Augmentations, NeurIPS 2020.

相关内容

热门资讯

股价走高触发强赎 7月将有两只... 近日,银行正股股价表现强势,多只银行转债触发强制赎回条款。7月1日是杭银转债最后一个交易日,其最后转...
300548,“改名”,股价历... 科技股和顺周期板块再现“跷跷板”走势。今天上午,顺周期板块走强,银行、有色金属、白酒、新能源等板块上...
明阳电路在昆山投资成立集成电路... 人民财讯7月2日电,企查查APP显示,近日,昆山华芯微测集成电路有限公司成立,法定代表人为窦旭才,经...
民生银行“易创E贷”获“202... 在近日举行的第七届数字普惠金融大会上,民生银行“易创E贷”产品凭借在服务科技型中小微企业的创新与实践...
机器人大军逼近,很快,亚马逊的...   炒股就看金麒麟分析师研报,权威,专业,及时,全面,助您挖掘潜力主题机会! 亚马逊正迅速接近仓库...
广东省教育考试院通报:不存在中... 7月2日,广东省教育考试院发布通报称,7月1日下午广东省初中学业水平考试(以下简称“中考”)数学科目...
国际油价承压 中东油企拟放缓全... 财联社7月2日讯(编辑 秦嘉禾)在国际油价下行压力加剧背景下,中东两大国有能源巨头——沙特阿美(Sa...
光伏50ETF、光伏龙头ETF... 光伏设备板块走强,亿晶光电、欧晶科技涨停。银华光伏50ETF、汇添富光伏龙头ETF、浦银安盛光伏龙头...
世界银行发布重磅预测:黄金今年... 错过上半年,别再错过下半年!世界银行继续看好贵金属前景,黄金、白银、铂金或将继续延续强势……世界银行...
洪灝称高分红投资策略应继续有所... 【#洪灝称高分红投资策略应继续有所表现#】中国股市近期一个显著趋势是分红和回购活动显著增加。其中,许...