• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    Pytorch训练网络过程中loss突然变为0的解决方案

    问题

    // loss 突然变成0
    python train.py -b=8
    INFO: Using device cpu
    INFO: Network:
            1 input channels
            7 output channels (classes)
            Bilinear upscaling
    INFO: Creating dataset with 868 examples
    INFO: Starting training:
            Epochs:          5
            Batch size:      8
            Learning rate:   0.001
            Training size:   782
            Validation size: 86
            Checkpoints:     True
            Device:          cpu
            Images scaling:  1
        
    Epoch 1/5:  10%|██████████████▏                                                                                                                            | 80/782 [01:3313:21,  1.14s/img, loss (batch)=0.886I
    NFO: Validation cross entropy: 1.86862473487854                                                                                                                                                                  
    Epoch 1/5:  20%|███████████████████████████▊                                                                                                            | 160/782 [03:3411:51,  1.14s/img, loss (batch)=2.35e-7I
    NFO: Validation cross entropy: 5.887489884504049e-10                                                                                                                                                             
    Epoch 1/5:  31%|███████████████████████████████████████████▌                                                                                                  | 240/782 [05:4111:29,  1.27s/img, loss (batch)=0I
    NFO: Validation cross entropy: 0.0                                                                                                                                                                               
    Epoch 1/5:  41%|██████████████████████████████████████████████████████████                                                                                    | 320/782 [07:4909:16,  1.20s/img, loss (batch)=0I
    NFO: Validation cross entropy: 0.0                                                                                                                                                                               
    Epoch 1/5:  51%|████████████████████████████████████████████████████████████████████████▋                                                                     | 400/782 [09:5507:31,  1.18s/img, loss (batch)=0I
    NFO: Validation cross entropy: 0.0                                                                                                                                                                               
    Epoch 1/5:  61%|███████████████████████████████████████████████████████████████████████████████████████▏                                                      | 480/782 [12:0205:58,  1.19s/img, loss (batch)=0I
    NFO: Validation cross entropy: 0.0                                                                                                                                                                               
    Epoch 1/5:  72%|█████████████████████████████████████████████████████████████████████████████████████████████████████▋                                        | 560/782 [14:0404:16,  1.15s/img, loss (batch)=0I
    NFO: Validation cross entropy: 0.0                                                                                                                                                                               
    Epoch 1/5:  82%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                         | 640/782 [16:1102:49,  1.20s/img, loss (batch)=0I
    NFO: Validation cross entropy: 0.0                                                                                                                                                                               
    Epoch 1/5:  92%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋           | 720/782 [18:2101:18,  1.26s/img, loss (batch)=0I
    NFO: Validation cross entropy: 0.0                                                                                                                                                                               
    Epoch 1/5:  94%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋        | 736/782 [19:1701:12,  1.57s/img, loss (batch)=0]
    Traceback (most recent call last):
      File "train.py", line 182, in module>
        val_percent=args.val / 100)
      File "train.py", line 66, in train_net
        for batch in train_loader:
      File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 819, in __next__
        return self._process_data(data)
      File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 846, in _process_data
        data.reraise()
      File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/_utils.py", line 385, in reraise
        raise self.exc_type(msg)
    RuntimeError: Caught RuntimeError in DataLoader worker process 4.
    Original Traceback (most recent call last):
      File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in default_collate
        return {key: default_collate([d[key] for d in batch]) for key in elem}
      File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in dictcomp>
        return {key: default_collate([d[key] for d in batch]) for key in elem}
      File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: Expected object of scalar type Double but got scalar type Byte for sequence element 4 in sequence argument at position #1 'tensors'
    

    交叉熵损失函数是衡量输出与标签之间的损失,通过求导确定梯度下降的方向。

    loss突然变为0,有两种可能性。

    一是因为预测输出为0,二是因为标签为0。

    如果是因为标签为0,那么一开始loss就可能为0.

    检查参数初始化

    检查前向传播的网络

    检查loss的计算格式

    检查梯度下降

    是否出现梯度消失。

    实际上是标签出了错误

    补充:pytorch训练出现loss=na

    遇到一个很坑的情况,在pytorch训练过程中出现loss=nan的情况

    有以下几种可能:

    1.学习率太高。

    2.loss函数有问题

    3.对于回归问题,可能出现了除0 的计算,加一个很小的余项可能可以解决

    4.数据本身,是否存在Nan、inf,可以用np.isnan(),np.isinf()检查一下input和target

    5.target本身应该是能够被loss函数计算的,比如sigmoid激活函数的target应该大于0,同样的需要检查数据集

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

    您可能感兴趣的文章:
    • 解决Pytorch半精度浮点型网络训练的问题
    • PyTorch梯度裁剪避免训练loss nan的操作
    • pytorch训练神经网络爆内存的解决方案
    • Pytorch训练模型得到输出后计算F1-Score 和AUC的操作
    • pytorch加载预训练模型与自己模型不匹配的解决方案
    • pytorch 如何使用float64训练
    上一篇:python 如何把classification_report输出到csv文件
    下一篇:pytorch中常用的损失函数用法说明
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    Pytorch训练网络过程中loss突然变为0的解决方案 Pytorch,训练,网络,过程中,