• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    pytorch实现线性回归

    pytorch实现线性回归代码练习实例,供大家参考,具体内容如下

    欢迎大家指正,希望可以通过小的练习提升对于pytorch的掌握

    # 随机初始化一个二维数据集,使用朋友torch训练一个回归模型
    import numpy as np
    import random
    import matplotlib.pyplot as plt
    
    x = np.arange(20)
    y = np.array([5*x[i] + random.randint(1,20) for i in range(len(x))])    # random.randint(参数1,参数2)函数返回参数1和参数2之间的任意整数
    print('-'*50)
    # 打印数据集
    print(x)
    print(y)
    
    import torch
    x_train = torch.from_numpy(x).float()
    y_train = torch.from_numpy(y).float()
    
    # model
    class LinearRegression(torch.nn.Module):
        def __init__(self):
            super(LinearRegression, self).__init__()
            # 输入与输出都是一维的
            self.linear = torch.nn.Linear(1,1)
        def forward(self,x):
            return self.linear(x)
    
    # 新建模型,误差函数,优化器
    model = LinearRegression()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(),0.001)
    # 开始训练
    num_epoch = 20
    for i in range(num_epoch):
        input_data = x_train.unsqueeze(1)
        target = y_train.unsqueeze(1)           # unsqueeze(1)在第二维增加一个维度
        out = model(input_data)
        loss = criterion(out,target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print("Eopch:[{}/{},loss:[{:.4f}]".format(i+1,num_epoch,loss.item()))
        if ((i+1)%2 == 0):
            predict = model(input_data)
            plt.plot(x_train.data.numpy(),predict.squeeze(1).data.numpy(),"r")
            loss = criterion(predict,target)
            plt.title("Loss:{:.4f}".format(loss.item()))
            plt.xlabel("X")
            plt.ylabel("Y")
            plt.scatter(x_train,y_train)
            plt.show()

    实验结果:

    以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • python深度总结线性回归
    • tensorflow基本操作小白快速构建线性回归和分类模型
    • 回归预测分析python数据化运营线性回归总结
    • python实现线性回归算法
    • python机器学习之线性回归详解
    • 使用pytorch实现线性回归
    • 详解TensorFlow2实现前向传播
    上一篇:python tkinter 获得按钮的文本值
    下一篇:python如何获取网络数据
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    pytorch实现线性回归 pytorch,实现,线性,回归,pytorch,