• 企业400电话
  • 网络优化推广
  • AI电话机器人
  • 呼叫中心
  • 全 部 栏 目

    网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    pytorch 实现在测试的时候启用dropout
    POST TIME:2021-10-18 13:45

    我们知道,dropout一般都在训练的时候使用,那么测试的时候如何也开启dropout呢?

    在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。

    想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:

    def apply_dropout(m):
        if type(m) == nn.Dropout:
            m.train()

    下面是完整demo代码:

    # coding: utf-8
    import torch
    import torch.nn as nn
    import numpy as np
    class SimpleNet(nn.Module):
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.fc = nn.Linear(8, 8)
            self.dropout = nn.Dropout(0.5)
        def forward(self, x):
            x = self.fc(x)
            x = self.dropout(x)
            return x
    net = SimpleNet()
    x = torch.FloatTensor([1]*8)
    net.train()
    y = net(x)
    print('train mode result: ', y)
    net.eval()
    y = net(x)
    print('eval mode result: ', y)
    net.eval()
    y = net(x)
    print('eval2 mode result: ', y)
    def apply_dropout(m):
        if type(m) == nn.Dropout:
            m.train()
    net.eval()
    net.apply(apply_dropout)
    y = net(x)
    print('apply eval result:', y)
    

    运行结果:

    可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。

    补充:Pytorch之dropout避免过拟合测试

    一.做数据

    二.搭建神经网络

    三.训练

    四.对比测试结果

    注意:测试过程中,一定要注意模式切换

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • 浅谈pytorch中的dropout的概率p
    • PyTorch 实现L2正则化以及Dropout的操作
    • Python深度学习pytorch神经网络Dropout应用详解解
    上一篇:使用Python脚本对GiteePages进行一键部署的使用说明
    下一篇:Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
  • 相关文章
  • 

    关于我们 | 付款方式 | 荣誉资质 | 业务提交 | 代理合作


    © 2016-2020 巨人网络通讯

    时间:9:00-21:00 (节假日不休)

    地址:江苏信息产业基地11号楼四层

    《增值电信业务经营许可证》 苏B2-20120278

    X

    截屏,微信识别二维码

    微信号:veteran88

    (点击微信号复制,添加好友)

     打开微信