解决Pytorch修改预训练模型时遇到key不匹配的情况
一、Pytorch修改预训练模型时遇到key不匹配
最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。
在我使用新赋值的网络模型时出现了key不匹配的问题
#加载后保存(未修改网络) base_weights = torch.load(args.save_folder + args.basenet) ssd_net.vgg.load_state_dict(base_weights) torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型 ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes']) net = ssd_net ... if args.resume: ... else: base_weights = torch.load(args.save_folder + args.basenet) #args.basenet为ssd_base.pth print('Loading base network...') ssd_net.vgg.load_state_dict(base_weights)
此时会如下出错误:
Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)
…
RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.
说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”
我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。
现在的问题是因为自己定义保存的模型key参数多了一个前缀。
可以通过如下语句进行修改,并加载
from collections import OrderedDict #导入此模块 base_weights = torch.load(args.save_folder + args.basenet) print('Loading base network...') new_state_dict = **OrderedDict()** for k, v in base_weights.items(): name = k[4:] # remove `vgg.`,即只取vgg.0.weights的后面几位 new_state_dict[name] = v ssd_net.vgg.load_state_dict(new_state_dict)
此时就不会再出错了。
参考了这个篇。修改一下就可以应用到自己的模型啦。
//www.jb51.net/article/214214.htm
二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘
最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。
KeyError: 'layer1.0.bn1.num_batches_tracked'
其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,
这个参数的作用如下:
训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1
如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).
其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.
所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.
有问题的代码:
def load_specific_param(self, state_dict, param_name, model_path): param_dict = torch.load(model_path) for i in state_dict: key = param_name + '.' + i state_dict[i].copy_(param_dict[key]) del param_dict
对'num_batches_tracked进行过滤:
def load_specific_param(self, state_dict, param_name, model_path): param_dict = torch.load(model_path) param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k} for i in state_dict: key = param_name + '.' + i if 'num_batches_tracked' in key: continue state_dict[i].copy_(param_dict[key]) del param_dict
以上为个人经验,希望能给大家一个参考,也希望大家多多支持猪先飞。
相关文章
- 这篇文章主要介绍了@CacheEvict 清除多个key的实现方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-02-13
pytorch nn.Conv2d()中的padding以及输出大小方式
今天小编就为大家分享一篇pytorch nn.Conv2d()中的padding以及输出大小方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27- 这篇文章主要介绍了@Cacheable 拼接key的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-02-13
- 这篇文章主要介绍了浅谈redis key值内存消耗以及性能影响,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-02-07
- 这篇文章主要介绍了PyTorch一小时掌握之迁移学习篇,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-09-08
- 下面小编就为大家带来一篇js遍历json的key和value的实例。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧...2017-01-26
Linux安装Pytorch1.8GPU(CUDA11.1)的实现
这篇文章主要介绍了Linux安装Pytorch1.8GPU(CUDA11.1)的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-25- 这篇文章主要介绍了uniapp微信小程序:key失效的解决方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-01-20
- 这篇文章主要介绍了Pytorch之扩充tensor的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-05
- 今天小编就为大家分享一篇pytorch 自定义卷积核进行卷积操作方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-06
- 这篇文章主要介绍了解决pytorch 交叉熵损失输出为负数的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-07-08
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
今天小编就为大家分享一篇pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- 这篇文章主要介绍了pytorch 实现冻结部分参数训练另一部分,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-27
- 本文给大家介绍如何替换json对象中的key,通过实例代码给大家介绍key的替换方法,代码也很简单,需要的朋友参考下吧...2021-06-02
从Pytorch模型pth文件中读取参数成numpy矩阵的操作
这篇文章主要介绍了从Pytorch模型pth文件中读取参数成numpy矩阵的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-04Pytorch 的损失函数Loss function使用详解
今天小编就为大家分享一篇Pytorch 的损失函数Loss function使用详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- 今天小编就为大家分享一篇pytorch中的上采样以及各种反操作,求逆操作详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-30
- 这篇文章主要介绍了基于Pytorch版yolov5的滑块验证码破解思路详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-02-25
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
今天小编就为大家分享一篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- 这篇文章主要介绍了解决postgresql 自增id作为key重复的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-02-04