• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    Tensorflow 如何从checkpoint文件中加载变量名和变量值

    假设你已经经过上千次的迭代,并且得到了以下模型:

    则从这些checkpoint文件中加载变量名和变量值代码如下:

    model_dir = './ckpt-182802'
    import tensorflow as tf
    from tensorflow.python import pywrap_tensorflow
    reader = pywrap_tensorflow.NewCheckpointReader(model_dir)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
         print("tensor_name: ", key)
         print(reader.get_tensor(key)) # Remove this is you want to print only variable names
    

    Mnist

    下面将给出一个基于卷积神经网络的手写数字识别样例:

    # -*- coding: utf-8 -*-
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    from tensorflow.python.framework import graph_util
    log_dir = './tensorboard'
    mnist = input_data.read_data_sets(train_dir="./mnist_data",one_hot=True)
    if tf.gfile.Exists(log_dir):
            tf.gfile.DeleteRecursively(log_dir)
    tf.gfile.MakeDirs(log_dir)
    
    #定义输入数据mnist图片大小28*28*1=784,None表示batch_size
    x = tf.placeholder(dtype=tf.float32,shape=[None,28*28],name="input")
    #定义标签数据,mnist共10类
    y_ = tf.placeholder(dtype=tf.float32,shape=[None,10],name="y_")
    #将数据调整为二维数据,w*H*c---> 28*28*1,-1表示N张
    image = tf.reshape(x,shape=[-1,28,28,1])
    
    #第一层,卷积核={5*5*1*32},池化核={2*2*1,1*2*2*1}
    w1 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,1,32],stddev=0.1,dtype=tf.float32,name="w1"))
    b1= tf.Variable(initial_value=tf.zeros(shape=[32]))
    conv1 = tf.nn.conv2d(input=image,filter=w1,strides=[1,1,1,1],padding="SAME",name="conv1")
    relu1 = tf.nn.relu(tf.nn.bias_add(conv1,b1),name="relu1")
    pool1 = tf.nn.max_pool(value=relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
    #shape={None,14,14,32}
    #第二层,卷积核={5*5*32*64},池化核={2*2*1,1*2*2*1}
    w2 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,32,64],stddev=0.1,dtype=tf.float32,name="w2"))
    b2 = tf.Variable(initial_value=tf.zeros(shape=[64]))
    conv2 = tf.nn.conv2d(input=pool1,filter=w2,strides=[1,1,1,1],padding="SAME")
    relu2 = tf.nn.relu(tf.nn.bias_add(conv2,b2),name="relu2")
    pool2 = tf.nn.max_pool(value=relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME",name="pool2")
    #shape={None,7,7,64}
    #FC1
    w3 = tf.Variable(initial_value=tf.random_normal(shape=[7*7*64,1024],stddev=0.1,dtype=tf.float32,name="w3"))
    b3 = tf.Variable(initial_value=tf.zeros(shape=[1024]))
    #关键,进行reshape
    input3 = tf.reshape(pool2,shape=[-1,7*7*64],name="input3")
    fc1 = tf.nn.relu(tf.nn.bias_add(value=tf.matmul(input3,w3),bias=b3),name="fc1")
    #shape={None,1024}
    #FC2
    w4 = tf.Variable(initial_value=tf.random_normal(shape=[1024,10],stddev=0.1,dtype=tf.float32,name="w4"))
    b4 = tf.Variable(initial_value=tf.zeros(shape=[10]))
    fc2 = tf.nn.bias_add(value=tf.matmul(fc1,w4),bias=b4,name="logit")
    #shape={None,10}
    #定义交叉熵损失
    # 使用softmax将NN计算输出值表示为概率
    y = tf.nn.softmax(fc2,name="out")
    
    # 定义交叉熵损失函数
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=fc2,labels=y_)
    loss = tf.reduce_mean(cross_entropy)
    tf.summary.scalar('Cross_Entropy',loss)
    #定义solver
    train = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss)
    for var in tf.trainable_variables():
    	print var
    #train = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss)
    
    #定义正确值,判断二者下标index是否相等
    correct_predict = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
    #定义如何计算准确率
    accuracy = tf.reduce_mean(tf.cast(correct_predict,dtype=tf.float32),name="accuracy")
    tf.summary.scalar('Training_ACC',accuracy)
    #定义初始化op
    merged = tf.summary.merge_all()
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    #训练NN
    with tf.Session() as session:
        session.run(fetches=init)
        writer = tf.summary.FileWriter(log_dir,session.graph) #定义记录日志的位置
        for i in range(0,500):
            xs, ys = mnist.train.next_batch(100)
            session.run(fetches=train,feed_dict={x:xs,y_:ys})
            if i%10 == 0:
                train_accuracy,summary = session.run(fetches=[accuracy,merged],feed_dict={x:xs,y_:ys})
                writer.add_summary(summary,i)
                print(i,"accuracy=",train_accuracy)
        '''
        #训练完成后,将网络中的权值转化为常量,形成常量graph,注意:需要x与label
        constant_graph = graph_util.convert_variables_to_constants(sess=session,
                                                                input_graph_def=session.graph_def,
                                                                output_node_names=['out','y_','input'])
        #将带权值的graph序列化,写成pb文件存储起来
        with tf.gfile.FastGFile("lenet.pb", mode='wb') as f:
            f.write(constant_graph.SerializeToString())
        '''
        saver.save(session,'./ckpt')
    
    

    补充:查看tensorflow产生的checkpoint文件内容的方法

    tensorflow在保存权重模型时多使用tf.train.Saver().save 函数进行权重保存,保存的ckpt文件无法直接打开,但tensorflow提供了相关函数 tf.train.NewCheckpointReader 可以对ckpt文件进行权重查看。

    import os
    from tensorflow.python import pywrap_tensorflow
    
    checkpoint_path = os.path.join('modelckpt', "fc_nn_model")
    # Read data from checkpoint file
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    # Print tensor name and values
    for key in var_to_shape_map:
        print("tensor_name: ", key)
        print(reader.get_tensor(key))
    

    其中‘modelckpt'是存放.ckpt文件的文件夹,"fc_nn_model"是文件名,如下图所示。

     

    var_to_shape_map是一个字典,其中的键值是变量名,对应的值是该变量的形状,如

    {‘LSTM_input/bias_LSTM/Adam_1': [128]}

    想要查看某变量值时,需要调用get_tensor函数,即输入以下代码:

    reader.get_tensor('LSTM_input/bias_LSTM/Adam_1')
    

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • 使用tensorflow 实现反向传播求导
    • TensorFlow的自动求导原理分析
    • tensorflow中的梯度求解及梯度裁剪操作
    • Python3安装tensorflow及配置过程
    • 解决tensorflow 与keras 混用之坑
    • tensorflow中的数据类型dtype用法说明
    上一篇:python基础学习之组织文件
    下一篇:pytorch 实现变分自动编码器的操作
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    Tensorflow 如何从checkpoint文件中加载变量名和变量值 Tensorflow,如何,从,checkpoint,