• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    pytorch交叉熵损失函数的weight参数的使用

    首先

    必须将权重也转为Tensor的cuda格式;

    然后

    将该class_weight作为交叉熵函数对应参数的输入值。

    class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

    补充:关于pytorch的CrossEntropyLoss的weight参数

    首先这个weight参数比想象中的要考虑的多

    你可以试试下面代码

    import torch
    import torch.nn as nn
    inputs = torch.FloatTensor([0,1,0,0,0,1])
    outputs = torch.LongTensor([0,1])
    inputs = inputs.view((1,3,2))
    outputs = outputs.view((1,2))
    weight_CE = torch.FloatTensor([1,1,1])
    ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
    loss = ce(inputs,outputs)
    print(loss)
    tensor(1.4803)

    这里的手动计算是:

    loss1 = 0 + ln(e0 + e0 + e0) = 1.098

    loss2 = 0 + ln(e1 + e0 + e1) = 1.86

    求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

    加权呢?

    import torch
    import torch.nn as nn
    inputs = torch.FloatTensor([0,1,0,0,0,1])
    outputs = torch.LongTensor([0,1])
    inputs = inputs.view((1,3,2))
    outputs = outputs.view((1,2))
    weight_CE = torch.FloatTensor([1,2,3])
    ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
    loss = ce(inputs,outputs)
    print(loss)
    
    tensor(1.6075)

    手算发现,并不是单纯的那权重相乘:

    loss1 = 0 + ln(e0 + e0 + e0) = 1.098

    loss2 = 0 + ln(e1 + e0 + e1) = 1.86

    求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

    而是

    loss1 = 0 + ln(e0 + e0 + e0) = 1.098

    loss2 = 0 + ln(e1 + e0 + e1) = 1.86

    求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

    发现了么,加权后,除以的是权重的和,不是数目的和。

    我们再验证一遍:

    import torch
    import torch.nn as nn
    inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
    outputs = torch.LongTensor([0,1,2,2])
    inputs = inputs.view((1,3,4))
    outputs = outputs.view((1,4))
    weight_CE = torch.FloatTensor([1,2,3])
    ce = nn.CrossEntropyLoss(weight=weight_CE)
    # ce = nn.CrossEntropyLoss(ignore_index=255)
    loss = ce(inputs,outputs)
    print(loss)
    
    tensor(1.5472)

    手算:

    loss1 = 0 + ln(e0 + e0 + e0) = 1.098

    loss2 = 0 + ln(e1 + e0 + e1) = 1.86

    loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

    loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

    求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

    可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • PyTorch的SoftMax交叉熵损失和梯度用法
    • pytorch中常用的损失函数用法说明
    • Pytorch十九种损失函数的使用详解
    • pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
    • Python机器学习pytorch交叉熵损失函数的深刻理解
    上一篇:pytorch 实现变分自动编码器的操作
    下一篇:pytorch 实现二分类交叉熵逆样本频率权重
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    pytorch交叉熵损失函数的weight参数的使用 pytorch,交叉,熵,损失,函数,