ALEXNET论文及其复现代码
创始人
2024-02-18 05:23:23
0

Alexnet-2012

文章目录

  • Alexnet-2012
    • 研究背景
      • top 5 error
    • 研究意义
    • 论文精读
      • Abstruct
      • 1. Introduction
      • 3. Architecture
        • 3.1 ReLU Nonlinearity
        • 3.2 Training on Multiple GPUs
        • 3.3 LRN local response normalization
        • 3.4 overlapping pooling
        • 3.5 Overall Architecture
      • 4. Reducing Overfitting
        • 4.1 Data Augmentation 图像增强
        • 4.2 DropOut
      • 实验结果和分析
        • 卷积核的可视化
        • 特征的相似性
    • 关键代码
  • torchvision介绍

研究背景

ILSVRC-2012 ImageNet Large Scale Visual Recognition Challenge.

类别训练数据测试数据图片格式
Mnist1050 00010 000Gray 28*28
Cifar-101050 00010 000RGB 32*32*3
ILSVRC-201210001200 000150 000RGB

ImageNet 数据集包含21,841个 类别, 14,197,122张图片。

top 5 error

AlexNet 在ILSVRC-2012以超出第二名10.9个百分点夺冠。

ModelTop-1 valTop5 valTop-5 test
SIFT+FVs--26.2%
1 CNN40.7%18.2%-
5 CNNs38.1%16.4%16.4%
1CNN*39.0%16.6%-
7CNN*s36.7%15.4%15.3%
  1. SIFT+FVS: ILSVRC-2012 分类任务第二名
  2. 1CNN 训练一个AlexNet
  3. 5CNNs 训练五个AlexNet取平均值。
  4. 1CNN* 在最后一个池化层后,额外添加第六个卷积层,并使用ImageNet 2011 秋季数据集上预训练。
  5. 7CNN*s 两个预训练微调,与5CNNs取平均值。

研究意义

  1. 里程碑式的论文
  2. 加速计算机视觉应用落地。 端到端式的,不需要再加特征工程。

论文精读

Abstruct

  1. ILSVRC-2010的120万张图片上训练AlexNet,最有结果: top1 error: 37.5, top-5 error 17%.
  2. 该网络由5个卷积层和3个全联接层组成,共计6000万个参数, 65万个神经元。
  3. 为加快训练,采用ReLU + GPU进行训练。
  4. 为减轻过拟合,采用Dropout.
  5. 基于上面的技巧,在ILSVRC-2012以超过第二名10.9个百分点的成绩夺冠。

1. Introduction

3. Architecture

3.1 ReLU Nonlinearity

ReLU Nonlinearity
f(x)=max(0,x)f(x) = max(0,x) f(x)=max(0,x)
Tanh 激活函数
f(x)=11+e−xf(x) = \frac{1}{1+ e^{-x}} f(x)=1+e−x1​

Relu的优点

  1. 从下面的图中可以看出,Re LU的训练速度是Tanh的6倍。用ReLU激活函数来训练模型是非常快的。
  2. 防止梯度消失。
  3. 使得网络具有稀疏性。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-u5AuKAgR-1669454493669)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125142750067.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jedPVQAr-1669454493670)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125143351852.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xsITmv18-1669454493670)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125143323365.png)]

3.2 Training on Multiple GPUs

3.3 LRN local response normalization

局部响应标准化

局部响应标准化:有助于AlexNet泛化能力的提升,受到真实神经元侧抑制启发。

侧抑制:细胞分化变为不同时,它会对周围细胞产生抑制信号,阻止它们向相同的方向分化,最终表现为细胞命运的不同。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PYPDrwYz-1669454493671)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125143654127.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-B1P6Olno-1669454493671)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125144420240.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bl8cgN2z-1669454493671)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125144449149.png)]

3.4 overlapping pooling

带重叠的池化层

3.5 Overall Architecture

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PDYK01zq-1669454493672)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125114705661.png)]

  1. 双GPU训练,CPU之间在某些层之间进行通信,
  2. 一共有八层:5个卷积层和3个全联接层。最后一个全联接层式一个1000类的softmax层。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RM3vdF0d-1669454493673)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125140113824.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4zYD3PC0-1669454493673)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125140142396.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-W7ktWEks-1669454493674)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125141951508.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rsxS6DuG-1669454493674)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125142006998.png)]

4. Reducing Overfitting

4.1 Data Augmentation 图像增强

方法1. 针对位置

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Gg67stBy-1669454493674)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125145455866.png)]

方法2 针对色彩

通过PCA方法修改RGB通道的像素值,实现颜色扰动,效果有限。

4.2 DropOut

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YcagXjJc-1669454493675)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125150515033.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yu9zZ1aw-1669454493675)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125150551510.png)]

实验结果和分析

卷积核的可视化

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EwOq5u9N-1669454493675)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125151740366.png)]

特征的相似性

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QtFT1dfh-1669454493676)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125152844529.png)]

相似图片的第二个全联接层输出特征的欧式距离相近。

启发:

可用alexnet 提取高级特征进行图像检索、图像聚类、图像编码。

关键代码

torch.topk(input, k, dim = None, largest=True, sorted = True, out = None)
"""
功能: 找出前k大的数据,及其索引序列号1. input : 张量
2. k 决定选取k个值
3. dim: 索引维度返回
1. Tensor: 前k个最大的值
2. LongTensor: 前k大的值所在的位置
"""
transforms.FiveCrop(size)
transforms.TenCrop(size, vertical_flip = False)"""
功能:在图片的上下左右及其中心裁出尺寸为size的五张图片, TenCrop 对这五张图片进行水平或者垂直镜像获得10张图片。
1. size: 所需要裁剪的尺寸
2. vertical_flip 是否要垂直翻转
"""
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range = None, scale_each = False, pad_value = 0)
"""
功能:制作网格图像
1. tensor: 图像数据, B*C*H*W 形式
2. nrow : 行数(列数自动计算)
3. padding : 图像间距(像素单位)
4. normalize: 是否将像素值标准化
5. range 标准化范围
6. scale_each: 是否单张图片维度标准化
7. pad_value: padding 的像素值
"""

torchvision介绍

torchvision 是pytoch的一个图形库, 他服务于pytorch深度学习框架, 主要用来构建计算机视觉模型。

torchvision.datasets: 一些加载数据的函数及常用的数据接口
torchvision.models: 包含常用的深度学习模型(含预训练模型)
torchvision.transforms: 常用的图像变化,例如裁剪,旋转等
torchvision.utils: 其他的一些有用的方法
class torchvision.transforms.Compose(transforms):# Composes several transforms together# parameters: transforms (list of transform objects) -list of transforms to composetransforms.Compose([
transforms.CenterCrop(10),
transforms.ToTensor(),
])
model.eval()
# 模型中有BatchNormalization和Dropout,在预测时使用model.eval()后会将其关闭以免影响预测结果。# https://blog.csdn.net/qq_38410428/article/details/101102075
      
# -*- coding: utf-8 -*-
"""
# @file name  : train_alexnet.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2020-02-14
# @brief      : alexnet traning
"""import osimport numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from ToolsUtils.my_dataset import CatDogDataset
from torch.utils.data import DataLoaderBASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")def get_model(path_state_dict, vis_model=False):"""创建模型,加载参数:param path_state_dict::return:"""model = models.alexnet()pretrained_state_dict = torch.load(path_state_dict)model.load_state_dict(pretrained_state_dict)if vis_model:from torchsummary import summarysummary(model, input_size=(3, 224, 224), device="cpu")model.to(device)return modelif __name__ == "__main__":# configdata_dir = os.path.join(BASE_DIR, "..", "data", "train")path_state_dict = os.path.join(BASE_DIR, "..", "data", "alexnet-owt-4df8aa71.pth")num_classes = 2MAX_EPOCH = 3       # 可自行修改BATCH_SIZE = 128    # 可自行修改LR = 0.001          # 可自行修改log_interval = 1    # 可自行修改val_interval = 1    # 可自行修改classes = 2start_epoch = -1lr_decay_step = 1   # 可自行修改# ============================ step 1/5 数据 ============================norm_mean = [0.485, 0.456, 0.406]norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.Resize((256)),      # (256, 256) 区别transforms.CenterCrop(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),])normalizes = transforms.Normalize(norm_mean, norm_std)valid_transform = transforms.Compose([transforms.Resize((256, 256)),transforms.TenCrop(224, vertical_flip=False),transforms.Lambda(lambda crops: torch.stack([normalizes(transforms.ToTensor()(crop)) for crop in crops])),])# 构建MyDataset实例train_data = CatDogDataset(data_dir=data_dir, mode="train", transform=train_transform)valid_data = CatDogDataset(data_dir=data_dir, mode="valid", transform=valid_transform)# 构建DataLodertrain_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)valid_loader = DataLoader(dataset=valid_data, batch_size=4)# ============================ step 2/5 模型 ============================alexnet_model = get_model(path_state_dict, False)num_ftrs = alexnet_model.classifier._modules["6"].in_featuresalexnet_model.classifier._modules["6"] = nn.Linear(num_ftrs, num_classes)alexnet_model.to(device)# ============================ step 3/5 损失函数 ============================criterion = nn.CrossEntropyLoss()# ============================ step 4/5 优化器 ============================# 冻结卷积层flag = 0# flag = 1if flag:fc_params_id = list(map(id, alexnet_model.classifier.parameters()))  # 返回的是parameters的 内存地址base_params = filter(lambda p: id(p) not in fc_params_id, alexnet_model.parameters())optimizer = optim.SGD([{'params': base_params, 'lr': LR * 0.1},  # 0{'params': alexnet_model.classifier.parameters(), 'lr': LR}], momentum=0.9)else:optimizer = optim.SGD(alexnet_model.parameters(), lr=LR, momentum=0.9)  # 选择优化器scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)  # 设置学习率下降策略# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(patience=5)# ============================ step 5/5 训练 ============================train_curve = list()valid_curve = list()for epoch in range(start_epoch + 1, MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.alexnet_model.train()for i, data in enumerate(train_loader):# if i > 1:#     break# forwardinputs, labels = datainputs, labels = inputs.to(device), labels.to(device)outputs = alexnet_model(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().cpu().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i+1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))loss_mean = 0.scheduler.step()  # 更新学习率# validate the modelif (epoch+1) % val_interval == 0:correct_val = 0.total_val = 0.loss_val = 0.alexnet_model.eval()with torch.no_grad():for j, data in enumerate(valid_loader):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)bs, ncrops, c, h, w = inputs.size()     # [4, 10, 3, 224, 224outputs = alexnet_model(inputs.view(-1, c, h, w))outputs_avg = outputs.view(bs, ncrops, -1).mean(1)loss = criterion(outputs_avg, labels)_, predicted = torch.max(outputs_avg.data, 1)total_val += labels.size(0)correct_val += (predicted == labels).squeeze().cpu().sum().numpy()loss_val += loss.item()loss_val_mean = loss_val/len(valid_loader)valid_curve.append(loss_val_mean)print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))alexnet_model.train()train_x = range(len(train_curve))train_y = train_curvetrain_iters = len(train_loader)valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterationsvalid_y = valid_curveplt.plot(train_x, train_y, label='Train')plt.plot(valid_x, valid_y, label='Valid')plt.legend(loc='upper right')plt.ylabel('loss value')plt.xlabel('Iteration')plt.savefig()plt.show()

相关内容

热门资讯

杨字的含义 杨字的含义 扬:张扬,自得的意思【解释】:趾高:走路时脚抬得很高;气扬:意气扬扬。走路时脚抬得很...
有梦子的四字成语? 有梦子的四字成语?白日做梦、魂牵梦萦、如梦初醒、梦寐以求、酣然入梦、半梦半醒、重温旧梦、夜长梦多、同...
恶人自有恶人磨 恶人自有恶人磨恶人自有恶人磨 (è rén zì yǒu è rén mó)解释:凶恶成性的人自然...
相对论中,火车断桥问题的答案是... 相对论中,火车断桥问题的答案是什么?这个假设唯一只有一个问题。败伏什么叫做“同时”?“只有两个发射器...
虎什么熊的成语 虎什么熊的成语这不是闹经急转弯虎背熊腰hǔ bèi xióng yāo成语解释如虎般宽厚的背;似熊样...
《亡念之扎姆德》男主角最后跟谁... 《亡念之扎姆德》男主角最后跟谁在一起?男主石化了九年,女主每天都来和他说话,然后九年后的第二天男主解...
火影忍者动画和漫画貌似不一样,... 火影忍者动画和漫画貌似不一样,海贼王动画和漫画一样吗?总是有些偏差的吧。。个人比较忠实原作。海贼王没...
让人非我弱,得志莫离群 让人非我弱,得志莫离群像投鼠忌器一样的意思吧,我躲,不是我怕你,而是我心有顾忌.不是因为势力差距而起...
《北宋小厨师》这本书更到现在男... 《北宋小厨师》这本书更到现在男猪脚泡到李师师和李清照了吗没有0.0....还没有啊因为还没结局
路边油炸的小摊上的酱是怎么做的... 路边油炸的小摊上的酱是怎么做的!要是家用,那可以选择用芝麻浆来做主配料.芝麻浆和水要1:1(水最好是...
《超禁忌游戏-五十分之一》应该... 《超禁忌游戏-五十分之一》应该完结了吧你要的是完整版的,但负责任地告诉你,现在不可能有,有也是骗你的...
公共经济学 答案 公共经济学 答案这个真不知道~!~谢谢~!~1.D2.D3.C4.B5.C6.D7.C8.B9.B1...
证券投资学 跟投资学有什么区别 证券投资学 跟投资学有什么区别投资学包括证券投资学。投资学包括各方面的投资学,比如黄金投资,期货投资...
忘记名字了,就是男主得到系统打... 忘记名字了,就是男主得到系统打英雄联盟,在联盟里边开挂可以变身眼可以变成野怪可以身穿求这部小说名字你...
个性签名为了你我愿意变成魔于全... 个性签名为了你我愿意变成魔于全世界为敌不爱那么多,只爱一点点,别人眉来又眼去,我只偷看你一眼。不要走...
野钓实用技巧 黑坑钓鱼技巧? 野钓实用技巧 黑坑钓鱼技巧?钓什么鱼要了解鱼的习性,了解对象鱼生活在哪个水层,喜欢吃什么食物,然后根...
江哲是那本书的? 江哲是那本书的?字随云的是《随波逐流之一代军师》字守义的是《三国之宅行天下》呵呵……这两本小说的江哲...
满满的生活经历是啥意思? 满满的生活经历是啥意思?满满的生活经历,说明的是这个人的生活阅历很深。
哪部国产青春剧比较贴近现实? 哪部国产青春剧比较贴近现实?《最好的我们》比较贴近现实,讲述的就是校园爱情故事,说的就是真实的高中生...
天涯海角与君共度 出自哪首歌呢... 天涯海角与君共度 出自哪首歌呢。云中歌主题曲丝罗李宇春的丝罗你好。楼主。李宇春《丝罗》伊本丝萝愿托乔...