• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    解决Pytorch修改预训练模型时遇到key不匹配的情况

    一、Pytorch修改预训练模型时遇到key不匹配

    最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

    在我使用新赋值的网络模型时出现了key不匹配的问题

    #加载后保存(未修改网络)
    base_weights = torch.load(args.save_folder + args.basenet)
    ssd_net.vgg.load_state_dict(base_weights) 
    torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
    
    # 将新保存的网络代替之前的预训练模型
        ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
        net = ssd_net
        ...
        if args.resume:
            ...
        else:
            base_weights = torch.load(args.save_folder + args.basenet)
            #args.basenet为ssd_base.pth
            print('Loading base network...')
            ssd_net.vgg.load_state_dict(base_weights) 
    

    此时会如下出错误:

    Loading base network…
    Traceback (most recent call last):
    File “train.py”, line 264, in
    train()
    File “train.py”, line 110, in train
    ssd_net.vgg.load_state_dict(base_weights)

    RuntimeError: Error(s) in loading state_dict for ModuleList:
    Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
    Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

    说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

    我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

    现在的问题是因为自己定义保存的模型key参数多了一个前缀。

    可以通过如下语句进行修改,并加载

    from collections import OrderedDict   #导入此模块
    base_weights = torch.load(args.save_folder + args.basenet)
    print('Loading base network...')
    new_state_dict = **OrderedDict()**  
    for k, v in base_weights.items():
        name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
        new_state_dict[name] = v 
        ssd_net.vgg.load_state_dict(new_state_dict) 

    此时就不会再出错了。

    参考了这个篇。修改一下就可以应用到自己的模型啦。

    //www.jb51.net/article/214214.htm

    二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

    最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

    KeyError: 'layer1.0.bn1.num_batches_tracked'

    其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

    这个参数的作用如下:

    训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

    如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

    其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

    所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

    有问题的代码:

       def load_specific_param(self, state_dict, param_name, model_path):
            param_dict = torch.load(model_path)
            for i in state_dict:
                key = param_name + '.' + i
                state_dict[i].copy_(param_dict[key])
            del param_dict

    对'num_batches_tracked进行过滤:

       def load_specific_param(self, state_dict, param_name, model_path):
            param_dict = torch.load(model_path)
            param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
            for i in state_dict:
                key = param_name + '.' + i
                if 'num_batches_tracked' in key:
                    continue
                state_dict[i].copy_(param_dict[key])
            del param_dict

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • Pytorch通过保存为ONNX模型转TensorRT5的实现
    • pytorch_pretrained_bert如何将tensorflow模型转化为pytorch模型
    • pytorch模型的保存和加载、checkpoint操作
    • PyTorch 如何检查模型梯度是否可导
    • pytorch 预训练模型读取修改相关参数的填坑问题
    • PyTorch模型转TensorRT是怎么实现的?
    上一篇:pytorch 预训练模型读取修改相关参数的填坑问题
    下一篇:python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    解决Pytorch修改预训练模型时遇到key不匹配的情况 解决,Pytorch,修改,预,训练,