• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    pytorch加载预训练模型与自己模型不匹配的解决方案

    pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

    两个有序字典找不同

    模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

    model = ResNet18(1)
    model_dict1 = torch.load('resnet18.pth')
    model_dict2 = model.state_dict()
    model_list1 = list(model_dict1.keys())
    model_list2 = list(model_dict2.keys())
    len1 = len(model_list1)
    len2 = len(model_list2)
    minlen = min(len1, len2)
    for n in range(minlen):
        if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
            err = 1

    自己搭建模型的注意事项

    搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

    model = ResNet18(1)
    model_dict1 = torch.load('resnet18.pth')
    model_dict2 = model.state_dict()
    model_list1 = list(model_dict1.keys())
    model_list2 = list(model_dict2.keys())
    len1 = len(model_list1)
    len2 = len(model_list2)
    minlen = min(len1, len2)
    for n in range(minlen):
        if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
            continue
        model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
    model.load_state_dict(model_dict2)

    完整的代码见自己搭建resnet18网络并加载torchvision自带权重

    新增的改进代码

    model_dict1 = torch.load('yolov5.pth')
    model_dict2 = model.state_dict()
    model_list1 = list(model_dict1.keys())
    model_list2 = list(model_dict2.keys())
    len1 = len(model_list1)
    len2 = len(model_list2)
    m, n = 0, 0
    while True:
        if m >= len1 or n >= len2:
            break
        layername1, layername2 = model_list1[m], model_list2[n]
        w1, w2 = model_dict1[layername1], model_dict2[layername2]
        if w1.shape != w2.shape:
            continue
        model_dict2[layername2] = model_dict1[layername1]
        m += 1
        n += 1
    model.load_state_dict(model_dict2)

    如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

    补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

    看代码吧~

    #打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
    #以及到第二个全连接层的全部网络还有他们对应的参数
    class Classification_att(nn.Module):
        def __init__(self, rgb_range):
            super(Classification_att, self).__init__()
            self.vgg19 =models.vgg19(pretrained=True)
            vgg = models.vgg19(pretrained=True).features
            conv_modules = [m for m in vgg]
            self.vgg_conv = nn.Sequential(*conv_modules[:37])
            classfi = models.vgg19(pretrained=True).classifier
            classif_modules = [n for n in classfi]
            self.vgg_class = nn.Sequential(*classif_modules[:4])
            vgg_mean = (0.485, 0.456, 0.406)
            vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
            self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
            for p in self.vgg_conv.parameters():
                p.requires_grad = False
            for p in self.vgg_class.parameters():
                p.requires_grad = False
            self.classifi = nn.Sequential(
                nn.Linear(4096, 1024),
                nn.ReLU(True),
                nn.Linear(1024, 256),
                nn.ReLU(True),
                nn.Linear(256, 64),
            )
     
        def forward(self, x):
            x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
            align_corners=False)
            x = self.sub_mean(x)
            x = self.vgg_conv(x)  
            x = self.vgg_class(x)  #执行这部报错,说张量不匹配

    原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

    查看vgg的pytorch源码发现是

    x = self.features(x)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)
    #自己的代码没有torch.flatten(x, 1)这步
    

    所以自己的少了一步

    x = torch.flatten(x, 1)

    补上就好了!

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • 解决Pytorch 加载训练好的模型 遇到的error问题
    • pytorch 更改预训练模型网络结构的方法
    • 解决Pytorch修改预训练模型时遇到key不匹配的情况
    上一篇:Python数据分析入门之教你怎么搭建环境
    下一篇:python执行js代码的方法
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    pytorch加载预训练模型与自己模型不匹配的解决方案 pytorch,加载,预,训练,模型,