• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    从Pytorch模型pth文件中读取参数成numpy矩阵的操作

    目的:

    把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备。

    Pytorch给了很方便的读取参数接口:

    nn.Module.parameters()

    直接看demo:

    from torchvision.models.alexnet import alexnet 
    model = alexnet(pretrained=True).eval().cuda()
    parameters = model.parameters()
    for p in parameters:
      numpy_para = p.detach().cpu().numpy()
      print(type(numpy_para))
      print(numpy_para.shape)

    上面得到的numpy_para就是numpy参数了~

    Note:

    model.parameters()是以一个生成器的形式迭代返回每一层的参数。所以用for循环读取到各层的参数,循环次数就表示层数。

    而每一层的参数都是torch.nn.parameter.Parameter类型,是Tensor的子类,所以直接用tensor转numpy(即p.detach().cpu().numpy())的方法就可以直接转成numpy矩阵。

    方便又好用,爆赞~

    补充:pytorch训练好的.pth模型转换为.pt

    将python训练好的.pth文件转为.pt

    import torch
    import torchvision
    from unet import UNet
    model = UNet(3, 2)#自己定义的网络模型
    model.load_state_dict(torch.load("best_weights.pth"))#保存的训练模型
    model.eval()#切换到eval()
    example = torch.rand(1, 3, 320, 480)#生成一个随机输入维度的输入
    traced_script_module = torch.jit.trace(model, example)
    traced_script_module.save("model.pt")

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

    您可能感兴趣的文章:
    • Numpy实现矩阵运算及线性代数应用
    • numpy数组合并和矩阵拼接的实现
    • numpy和tensorflow中的各种乘法(点乘和矩阵乘)
    • NumPy 矩阵乘法的实现示例
    • Python numpy大矩阵运算内存不足如何解决
    • 使用numpy实现矩阵的翻转(flip)与旋转
    上一篇:python 如何用urllib与服务端交互(发送和接收数据)
    下一篇:pytorch 计算Parameter和FLOP的操作
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    从Pytorch模型pth文件中读取参数成numpy矩阵的操作 从,Pytorch,模型,pth,文件,