• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    pytorch 6 batch_train 批训练操作

    看代码吧~

    import torch
    import torch.utils.data as Data
    torch.manual_seed(1)    # reproducible
    # BATCH_SIZE = 5  
    BATCH_SIZE = 8      # 每次使用8个数据同时传入网路
    x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
    y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
    torch_dataset = Data.TensorDataset(x, y)
    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=False,              # 设置不随机打乱数据 random shuffle for training
        num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data
    )
    def show_batch():
        for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times
            for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
                # train your data...
                print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                      batch_x.numpy(), '| batch y: ', batch_y.numpy())
    if __name__ == '__main__':
        show_batch()
    

    BATCH_SIZE = 8 , 所有数据利用三次

    Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
    Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
    Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
    Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
    Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
    Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

    补充:pytorch批训练bug

    问题描述:

    在进行pytorch神经网络批训练的时候,有时会出现报错 

    TypeError: batch must contain tensors, numbers, dicts or lists; found class 'torch.autograd.variable.Variable'>

    解决办法:

    第一步:

    检查(重点!!!!!):

    train_dataset = Data.TensorDataset(train_x, train_y)

    train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable

    可以这样将数据变为tensor类:

    train_x = torch.FloatTensor(train_x)

    第二步:

    train_loader = Data.DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=True
        )

    实例化一个DataLoader对象

    第三步:

        for epoch in range(epochs):
            for step, (batch_x, batch_y) in enumerate(train_loader):
                batch_x, batch_y = Variable(batch_x), Variable(batch_y)

    这样就可以批训练了

    需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • 详解PyTorch批训练及优化器比较
    • pytorch 固定部分参数训练的方法
    • pytorch 准备、训练和测试自己的图片数据的方法
    • pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
    上一篇:Python函数装饰器的使用教程
    下一篇:Keras多线程机制与flask多线程冲突的解决方案
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    pytorch 6 batch_train 批训练操作 pytorch,batch,train,批,训练,