ALEXNET论文及其复现代码
创始人
2024-02-18 05:23:23
0

Alexnet-2012

文章目录

  • Alexnet-2012
    • 研究背景
      • top 5 error
    • 研究意义
    • 论文精读
      • Abstruct
      • 1. Introduction
      • 3. Architecture
        • 3.1 ReLU Nonlinearity
        • 3.2 Training on Multiple GPUs
        • 3.3 LRN local response normalization
        • 3.4 overlapping pooling
        • 3.5 Overall Architecture
      • 4. Reducing Overfitting
        • 4.1 Data Augmentation 图像增强
        • 4.2 DropOut
      • 实验结果和分析
        • 卷积核的可视化
        • 特征的相似性
    • 关键代码
  • torchvision介绍

研究背景

ILSVRC-2012 ImageNet Large Scale Visual Recognition Challenge.

类别训练数据测试数据图片格式
Mnist1050 00010 000Gray 28*28
Cifar-101050 00010 000RGB 32*32*3
ILSVRC-201210001200 000150 000RGB

ImageNet 数据集包含21,841个 类别, 14,197,122张图片。

top 5 error

AlexNet 在ILSVRC-2012以超出第二名10.9个百分点夺冠。

ModelTop-1 valTop5 valTop-5 test
SIFT+FVs--26.2%
1 CNN40.7%18.2%-
5 CNNs38.1%16.4%16.4%
1CNN*39.0%16.6%-
7CNN*s36.7%15.4%15.3%
  1. SIFT+FVS: ILSVRC-2012 分类任务第二名
  2. 1CNN 训练一个AlexNet
  3. 5CNNs 训练五个AlexNet取平均值。
  4. 1CNN* 在最后一个池化层后,额外添加第六个卷积层,并使用ImageNet 2011 秋季数据集上预训练。
  5. 7CNN*s 两个预训练微调,与5CNNs取平均值。

研究意义

  1. 里程碑式的论文
  2. 加速计算机视觉应用落地。 端到端式的,不需要再加特征工程。

论文精读

Abstruct

  1. ILSVRC-2010的120万张图片上训练AlexNet,最有结果: top1 error: 37.5, top-5 error 17%.
  2. 该网络由5个卷积层和3个全联接层组成,共计6000万个参数, 65万个神经元。
  3. 为加快训练,采用ReLU + GPU进行训练。
  4. 为减轻过拟合,采用Dropout.
  5. 基于上面的技巧,在ILSVRC-2012以超过第二名10.9个百分点的成绩夺冠。

1. Introduction

3. Architecture

3.1 ReLU Nonlinearity

ReLU Nonlinearity
f(x)=max(0,x)f(x) = max(0,x) f(x)=max(0,x)
Tanh 激活函数
f(x)=11+e−xf(x) = \frac{1}{1+ e^{-x}} f(x)=1+e−x1​

Relu的优点

  1. 从下面的图中可以看出,Re LU的训练速度是Tanh的6倍。用ReLU激活函数来训练模型是非常快的。
  2. 防止梯度消失。
  3. 使得网络具有稀疏性。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-u5AuKAgR-1669454493669)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125142750067.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jedPVQAr-1669454493670)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125143351852.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xsITmv18-1669454493670)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125143323365.png)]

3.2 Training on Multiple GPUs

3.3 LRN local response normalization

局部响应标准化

局部响应标准化:有助于AlexNet泛化能力的提升,受到真实神经元侧抑制启发。

侧抑制:细胞分化变为不同时,它会对周围细胞产生抑制信号,阻止它们向相同的方向分化,最终表现为细胞命运的不同。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PYPDrwYz-1669454493671)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125143654127.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-B1P6Olno-1669454493671)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125144420240.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bl8cgN2z-1669454493671)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125144449149.png)]

3.4 overlapping pooling

带重叠的池化层

3.5 Overall Architecture

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PDYK01zq-1669454493672)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125114705661.png)]

  1. 双GPU训练,CPU之间在某些层之间进行通信,
  2. 一共有八层:5个卷积层和3个全联接层。最后一个全联接层式一个1000类的softmax层。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RM3vdF0d-1669454493673)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125140113824.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4zYD3PC0-1669454493673)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125140142396.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-W7ktWEks-1669454493674)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125141951508.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rsxS6DuG-1669454493674)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125142006998.png)]

4. Reducing Overfitting

4.1 Data Augmentation 图像增强

方法1. 针对位置

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Gg67stBy-1669454493674)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125145455866.png)]

方法2 针对色彩

通过PCA方法修改RGB通道的像素值,实现颜色扰动,效果有限。

4.2 DropOut

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YcagXjJc-1669454493675)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125150515033.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yu9zZ1aw-1669454493675)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125150551510.png)]

实验结果和分析

卷积核的可视化

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EwOq5u9N-1669454493675)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125151740366.png)]

特征的相似性

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QtFT1dfh-1669454493676)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125152844529.png)]

相似图片的第二个全联接层输出特征的欧式距离相近。

启发:

可用alexnet 提取高级特征进行图像检索、图像聚类、图像编码。

关键代码

torch.topk(input, k, dim = None, largest=True, sorted = True, out = None)
"""
功能: 找出前k大的数据,及其索引序列号1. input : 张量
2. k 决定选取k个值
3. dim: 索引维度返回
1. Tensor: 前k个最大的值
2. LongTensor: 前k大的值所在的位置
"""
transforms.FiveCrop(size)
transforms.TenCrop(size, vertical_flip = False)"""
功能:在图片的上下左右及其中心裁出尺寸为size的五张图片, TenCrop 对这五张图片进行水平或者垂直镜像获得10张图片。
1. size: 所需要裁剪的尺寸
2. vertical_flip 是否要垂直翻转
"""
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range = None, scale_each = False, pad_value = 0)
"""
功能:制作网格图像
1. tensor: 图像数据, B*C*H*W 形式
2. nrow : 行数(列数自动计算)
3. padding : 图像间距(像素单位)
4. normalize: 是否将像素值标准化
5. range 标准化范围
6. scale_each: 是否单张图片维度标准化
7. pad_value: padding 的像素值
"""

torchvision介绍

torchvision 是pytoch的一个图形库, 他服务于pytorch深度学习框架, 主要用来构建计算机视觉模型。

torchvision.datasets: 一些加载数据的函数及常用的数据接口
torchvision.models: 包含常用的深度学习模型(含预训练模型)
torchvision.transforms: 常用的图像变化,例如裁剪,旋转等
torchvision.utils: 其他的一些有用的方法
class torchvision.transforms.Compose(transforms):# Composes several transforms together# parameters: transforms (list of transform objects) -list of transforms to composetransforms.Compose([
transforms.CenterCrop(10),
transforms.ToTensor(),
])
model.eval()
# 模型中有BatchNormalization和Dropout,在预测时使用model.eval()后会将其关闭以免影响预测结果。# https://blog.csdn.net/qq_38410428/article/details/101102075
      
# -*- coding: utf-8 -*-
"""
# @file name  : train_alexnet.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2020-02-14
# @brief      : alexnet traning
"""import osimport numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from ToolsUtils.my_dataset import CatDogDataset
from torch.utils.data import DataLoaderBASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")def get_model(path_state_dict, vis_model=False):"""创建模型,加载参数:param path_state_dict::return:"""model = models.alexnet()pretrained_state_dict = torch.load(path_state_dict)model.load_state_dict(pretrained_state_dict)if vis_model:from torchsummary import summarysummary(model, input_size=(3, 224, 224), device="cpu")model.to(device)return modelif __name__ == "__main__":# configdata_dir = os.path.join(BASE_DIR, "..", "data", "train")path_state_dict = os.path.join(BASE_DIR, "..", "data", "alexnet-owt-4df8aa71.pth")num_classes = 2MAX_EPOCH = 3       # 可自行修改BATCH_SIZE = 128    # 可自行修改LR = 0.001          # 可自行修改log_interval = 1    # 可自行修改val_interval = 1    # 可自行修改classes = 2start_epoch = -1lr_decay_step = 1   # 可自行修改# ============================ step 1/5 数据 ============================norm_mean = [0.485, 0.456, 0.406]norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.Resize((256)),      # (256, 256) 区别transforms.CenterCrop(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),])normalizes = transforms.Normalize(norm_mean, norm_std)valid_transform = transforms.Compose([transforms.Resize((256, 256)),transforms.TenCrop(224, vertical_flip=False),transforms.Lambda(lambda crops: torch.stack([normalizes(transforms.ToTensor()(crop)) for crop in crops])),])# 构建MyDataset实例train_data = CatDogDataset(data_dir=data_dir, mode="train", transform=train_transform)valid_data = CatDogDataset(data_dir=data_dir, mode="valid", transform=valid_transform)# 构建DataLodertrain_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)valid_loader = DataLoader(dataset=valid_data, batch_size=4)# ============================ step 2/5 模型 ============================alexnet_model = get_model(path_state_dict, False)num_ftrs = alexnet_model.classifier._modules["6"].in_featuresalexnet_model.classifier._modules["6"] = nn.Linear(num_ftrs, num_classes)alexnet_model.to(device)# ============================ step 3/5 损失函数 ============================criterion = nn.CrossEntropyLoss()# ============================ step 4/5 优化器 ============================# 冻结卷积层flag = 0# flag = 1if flag:fc_params_id = list(map(id, alexnet_model.classifier.parameters()))  # 返回的是parameters的 内存地址base_params = filter(lambda p: id(p) not in fc_params_id, alexnet_model.parameters())optimizer = optim.SGD([{'params': base_params, 'lr': LR * 0.1},  # 0{'params': alexnet_model.classifier.parameters(), 'lr': LR}], momentum=0.9)else:optimizer = optim.SGD(alexnet_model.parameters(), lr=LR, momentum=0.9)  # 选择优化器scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)  # 设置学习率下降策略# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(patience=5)# ============================ step 5/5 训练 ============================train_curve = list()valid_curve = list()for epoch in range(start_epoch + 1, MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.alexnet_model.train()for i, data in enumerate(train_loader):# if i > 1:#     break# forwardinputs, labels = datainputs, labels = inputs.to(device), labels.to(device)outputs = alexnet_model(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().cpu().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i+1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))loss_mean = 0.scheduler.step()  # 更新学习率# validate the modelif (epoch+1) % val_interval == 0:correct_val = 0.total_val = 0.loss_val = 0.alexnet_model.eval()with torch.no_grad():for j, data in enumerate(valid_loader):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)bs, ncrops, c, h, w = inputs.size()     # [4, 10, 3, 224, 224outputs = alexnet_model(inputs.view(-1, c, h, w))outputs_avg = outputs.view(bs, ncrops, -1).mean(1)loss = criterion(outputs_avg, labels)_, predicted = torch.max(outputs_avg.data, 1)total_val += labels.size(0)correct_val += (predicted == labels).squeeze().cpu().sum().numpy()loss_val += loss.item()loss_val_mean = loss_val/len(valid_loader)valid_curve.append(loss_val_mean)print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))alexnet_model.train()train_x = range(len(train_curve))train_y = train_curvetrain_iters = len(train_loader)valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterationsvalid_y = valid_curveplt.plot(train_x, train_y, label='Train')plt.plot(valid_x, valid_y, label='Valid')plt.legend(loc='upper right')plt.ylabel('loss value')plt.xlabel('Iteration')plt.savefig()plt.show()

相关内容

热门资讯

六件劳动争议典型案例传递实质公... 转自:河工新闻网  日前,河北省高级人民法院向社会通报2020-2024年河北法院劳动者权益司法保护...
5月12日生意社菜籽粕基准价为... 生意社05月12日讯 5月12日,生意社菜籽粕基准价为2506.67元/吨,与本月初(...
公益纪实节目《艾在归途》:讲述... 转自:中国台湾网  “边看边流泪,亲情深似海,两岸同胞都是中国人”“中国人是有根的!这就是两岸一家亲...
中央气象台:天 气 公 报(2... 来源:中央气象台网站华北黄淮部分地区有高温天气发展南方地区将有新一轮降水过程摘要:国内方面,昨日,全...
快速公交天桥站设施齐备却10年... 转自:北京日报客户端市民近日向北京日报客户端反映,快速公交1号线天桥站建成已有十个年头,设施齐全却一...
宁德时代:副董事长李平及其配偶... 转自:财联社【宁德时代:副董事长李平及其配偶拟向复旦大学教育发展基金会无偿捐赠405万股股票】财联社...
CBOT小麦下跌,因作物状况持... 原标题:CBOT小麦下跌,因作物状况持续改善 来源:南方小麦网周五,芝加哥期货交易所(CBOT...
从“一类事”到“多类事”,舟山... 转自:中国环境网作为现存已知最古老的鱼类之一,中华鲟已在地球上生存了1.4亿年,被誉为“水中活化石”...
海南对7市县开展生态环境执法稽... 转自:中国环境网为规范生态环境行政执法行为,近日,海南省生态环境厅对海口、三亚、澄迈、临高、屯昌、乐...
吉林长龙药业将于7月18日派发... .ct_hqimg {margin: 10px 0;} .hqimg_wrapper {text-a...
阿里夸克深度搜索:让AI更懂普... 来源:钛媒体AI大模型以深度思考惊艳了世界之后,下一个阶段,AI应用要往哪去?阿里AI 旗舰应用——...
5月12日生意社干香菇基准价为... 生意社05月12日讯 5月12日,生意社干香菇基准价为55.88元/公斤,与本月初(5...
沙特国民银行与阿雷布资本签58... (转自:观点网)观点网讯:5月12日,沙特国民银行与阿雷布资本签署一项符合伊斯兰教法的融资协议,融资...
今年以来险企已举牌十三次 转自:中国银行保险报网□本报记者 朱艳霞近日,中邮保险通过协议转让方式受让东航物流7942.01万股...
黑龙江省实现国家级幸福河湖建设... 转自:黑龙江发布近日,水利部办公厅印发《关于实施2025年幸福河湖建设项目的通知》,公布2025年幸...
播恩集团:2025年4月30日... 投资者提问:请问截止2025年4月30日,公司股东人数多少?董秘回答(播恩集团SZ001366):尊...
第二十届“光博会”本周武汉开幕...     编者按    第20届“光博会”如约而至,光谷、武汉、湖北,本周将再次成为全球光电子信息产业...
五大国有行AIC设立八年总资产...   长江商报消息 ●长江商报记者 徐佳  时隔八年,金融资产投资公司(即“AIC”)队伍扩容。  5...
药师帮5月9日回购30万股股份 .ct_hqimg {margin: 10px 0;} .hqimg_wrapper {text-a...
【深化“三抓三促”行动 群众有... 【深化“三抓三促”行动 群众有话说】“气象服务让我们产出好茶”  新甘肃·甘肃日报记者 海晓宁  近...