本文总结了 AlexNet 的模型结构框图以及对应的 PyTorch 模型代码。
代码
import torch
from torch import nn
import torch.nn.functional as F
class AlexNet(nn.Module):
    """AlexNet-style convolutional image classifier.

    Five convolutional stages (``layer1``-``layer5``) are followed by three
    fully connected stages (``layer6``-``layer8``).

    Args:
        num_classes: Number of output classes, i.e. the size of the final
            linear layer. Defaults to 1000.
        apply_softmax: If True (default), ``forward`` returns per-class
            probabilities (softmax over dim 1). Set it to False to get raw
            logits, e.g. when training with ``nn.CrossEntropyLoss``, which
            applies log-softmax internally and expects unnormalized scores.

    The convolutional stages expect 3 input channels. With a 227x227 input
    they produce a 256x6x6 feature map. An adaptive average pool resizes
    other spatial sizes (e.g. 224x224, which yields 5x5) to 6x6 so the
    classifier input size always matches.
    """

    def __init__(self, num_classes: int = 1000, apply_softmax: bool = True):
        super().__init__()
        self.apply_softmax = apply_softmax

        # Conv1: 11x11 kernel, stride 4, 96 channels; then 3x3/2 max-pool.
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Conv2: 5x5 kernel, "same" padding, 256 channels; then 3x3/2 max-pool.
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Conv3-Conv5: 3x3 kernels with padding 1, so spatial size is kept.
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Resizes any feature map to 6x6. It has no parameters, so state_dict
        # keys are unchanged, and it is the identity for a 6x6 input
        # (the 227x227 case).
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        # FC6 and FC7: hidden fully connected layers, each followed by dropout
        # (p=0.5, the nn.Dropout default) as in the original AlexNet.
        self.layer6 = nn.Sequential(
            nn.Linear(in_features=6 * 6 * 256, out_features=4096),
            nn.ReLU(),
            nn.Dropout(),
        )
        self.layer7 = nn.Sequential(
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Dropout(),
        )
        # FC8: one output score per class.
        self.layer8 = nn.Linear(in_features=4096, out_features=num_classes)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape (batch, 3, H, W). H and W must be large enough
                to survive the strided conv/pool stages; 227x227 and 224x224
                both work.

        Returns:
            Tensor of shape (batch, num_classes). It holds softmax
            probabilities when ``self.apply_softmax`` is true, and raw logits
            otherwise.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.avgpool(x)
        # Flatten everything except the batch dimension. Unlike .view(),
        # this also works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = self.layer6(x)
        x = self.layer7(x)
        x = self.layer8(x)
        if self.apply_softmax:
            return F.softmax(x, dim=1)
        return x
版权声明：本文为yuejich原创文章，遵循CC 4.0 BY-SA版权协议，转载请附上原文出处链接和本声明。