使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。
构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。
加载预训练alexnet之后,可以print出来查看模型的结构及信息:
model = models.alexnet(pretrained=True) print(model)
分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。
下面放出完整的搭建代码:
import torch.nn as nn from torchvision import models class BuildAlexNet(nn.Module): def __init__(self, model_type, n_output): super(BuildAlexNet, self).__init__() self.model_type = model_type if model_type == 'pre': model = models.alexnet(pretrained=True) self.features = model.features fc1 = nn.Linear(9216, 4096) fc1.bias = model.classifier[1].bias fc1.weight = model.classifier[1].weight fc2 = nn.Linear(4096, 4096) fc2.bias = model.classifier[4].bias fc2.weight = model.classifier[4].weight self.classifier = nn.Sequential( nn.Dropout(), fc1, nn.ReLU(inplace=True), nn.Dropout(), fc2, nn.ReLU(inplace=True), nn.Linear(4096, n_output)) #或者直接修改为 # model.classifier[6]==nn.Linear(4096,n_output) # self.classifier = model.classifier if model_type == 'new': self.features = nn.Sequential( nn.Conv2d(3, 64, 11, 4, 2), nn.ReLU(inplace = True), nn.MaxPool2d(3, 2, 0), nn.Conv2d(64, 192, 5, 1, 2), nn.ReLU(inplace=True), nn.MaxPool2d(3, 2, 0), nn.Conv2d(192, 384, 3, 1, 1), nn.ReLU(inplace = True), nn.Conv2d(384, 256, 3, 1, 1), nn.ReLU(inplace=True), nn.MaxPool2d(3, 2, 0)) self.classifier = nn.Sequential( nn.Dropout(), nn.Linear(9216, 4096), nn.ReLU(inplace=True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Linear(4096, n_output)) def forward(self, x): x = self.features(x) x = x.view(x.size(0), -1) out = self.classifier(x) return out
微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。
网络搭好之后可以小小的测试一下以检验维度是否正确。
import numpy as np from torch.autograd import Variable import torch if __name__ == '__main__': model_type = 'pre' n_output = 10 alexnet = BuildAlexNet(model_type, n_output) print(alexnet) x = np.random.rand(1,3,224,224) x = x.astype(np.float32) x_ts = torch.from_numpy(x) x_in = Variable(x_ts) y = alexnet(x_in)
这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。
输出y.data.numpy()可得10维输出,表明网络搭建正确。
以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持猪先飞。
相关文章
pytorch nn.Conv2d()中的padding以及输出大小方式
今天小编就为大家分享一篇pytorch nn.Conv2d()中的padding以及输出大小方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27Linux安装Pytorch1.8GPU(CUDA11.1)的实现
这篇文章主要介绍了Linux安装Pytorch1.8GPU(CUDA11.1)的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-25- 这篇文章主要介绍了PyTorch一小时掌握之迁移学习篇,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-09-08
- 今天小编就为大家分享一篇pytorch 自定义卷积核进行卷积操作方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-06
- 这篇文章主要介绍了Pytorch之扩充tensor的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-05
- 这篇文章主要介绍了解决pytorch 交叉熵损失输出为负数的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-07-08
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
今天小编就为大家分享一篇pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- 这篇文章主要介绍了pytorch 实现冻结部分参数训练另一部分,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-27
从Pytorch模型pth文件中读取参数成numpy矩阵的操作
这篇文章主要介绍了从Pytorch模型pth文件中读取参数成numpy矩阵的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-04Pytorch 的损失函数Loss function使用详解
今天小编就为大家分享一篇Pytorch 的损失函数Loss function使用详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- 今天小编就为大家分享一篇pytorch中的上采样以及各种反操作,求逆操作详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-30
- 今天小编就为大家分享一篇Pytorch实现LSTM和GRU示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
- 这篇文章主要介绍了基于Pytorch版yolov5的滑块验证码破解思路详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-02-25
解决Pytorch dataloader时报错每个tensor维度不一样的问题
这篇文章主要介绍了解决Pytorch dataloader时报错每个tensor维度不一样的问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教...2021-05-28pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
今天小编就为大家分享一篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- 这篇文章主要介绍了pytorch深度学习中对softmax实现进行了详细解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步...2021-09-30
- 今天小编就为大家分享一篇Pytorch 计算误判率,计算准确率,计算召回率的例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
- 这篇文章主要介绍了pytorch中的squeeze函数、cat函数使用,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教...2021-05-20
- 这篇文章主要介绍了Pytorch如何切换 cpu和gpu的使用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-01
- 今天小编就为大家分享一篇pytorch动态网络以及权重共享实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-29