• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    Tensorflow与RNN、双向LSTM等的踩坑记录及解决

    1、tensorflow(不定长)文本序列读取与解析

    tensorflow读取csv时需要指定各列的数据类型。

    但是对于RNN这种接受序列输入的模型来说,一条序列的长度是不固定。这时如果使用csv存储序列数据,应当首先将特征序列拼接成一列。

    例如两条数据序列,第一项是标签,之后是特征序列

    [0, 1.1, 1.2, 2.3] 转换成 [0, '1.1_1.2_2.3']

    [1, 1.0, 2.5, 1.6, 3.2, 4.5] 转换成 [1, '1.0_2.5_1.6_3.2_4.5']

    这样每条数据都只包含固定两列了。

    读取方式是指定第二列为字符串类型,再将字符串按照'_'分割并转换为数字。

    关键的几行代码示例如下:

    def readMyFileFormat(fileNameQueue):
        reader = tf.TextLineReader()
        key, value = reader.read(fileNameQueue)
    
        record_defaults = [["Null"], [-1], ["Null"], ["Null"], [-1]]
        phone1, seqlen, ts_diff_strseq, t_cod_strseq, userlabel = tf.decode_csv(value, record_defaults=record_defaults)
        ts_diff_str = tf.string_split([ts_diff_strseq], delimiter='_')
        t_cod_str = tf.string_split([t_cod_strseq], delimiter='_')
        # 每个字符串转数字
        Str2Float = lambda string: tf.string_to_number(string, tf.float32)
        Str2Int = lambda string: tf.string_to_number(string, tf.int32)
        ts_diff_seq = tf.map_fn(Str2Float, ts_diff_str.values, dtype = tf.float32) # 一定要加上dtype,且必须与fn的输出类型一致
        t_cod_seq = tf.map_fn(Str2Int, t_cod_str.values, dtype = tf.int32)
    

    2、时序建模的序列预测、序列拟合、标签预测,及输入数据格式

    序列预测、拟合的“标签”都是序列本身,区别是未来时刻或者是当前时刻,当前时刻的拟合任务类似于antoencoder的reconstruction

    标签预测常见于语言学建模,有单词级标签的分词与整句标签的情感分析,前者需要对每一个单词输入都要输出其分词标识,后者是取最后若干输出级联前馈神经网络分类器

    keras的输入-输出对:需要将序列拆分成多个片段

    序列形式:

    按时间列表:static_bidirectional_rnn

    多维数组:bidirectional_dynamic_rnn与stack_bidirectional_dynamic_rnn 变长双向rnn的正确使用姿势

    3、多任务设置及相应的输出向量划分

    对于标签预测任务,按需取输出即可

    对于序列预测、拟合:

    双向lstm:通常用于拟合。但如果需要捕捉动态信息,尽管需要序列完整输入,则仍可以加上正向预测与反向预测

    单向lstm:拟合与预测

    4、zero padding

    后一般需要通过tf.boolean_mask()隔离这些零的影响,函数输入包括数据矩阵和补零位置的指示矩阵。

    5、get_shape()方法

    与 tf.shape() 类型区别,前者得到一个list,后者得到一个tensor

    6、双向LSTM的信息瓶颈的解决

    如果在时间步的最后输出,则可能会导致开始的一些字符被遗忘门给遗忘。

    所以这里就对每个时间步的输出做出了处理,

    主要处理有:

    1、拼接:把所有的输出拼接在一起。

    2、Average

    3、Pooling

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • 教你使用TensorFlow2识别验证码
    • pytorch_pretrained_bert如何将tensorflow模型转化为pytorch模型
    • TensorFlow中tf.batch_matmul()的用法
    • tensorflow中的数据类型dtype用法说明
    • tensorflow基本操作小白快速构建线性回归和分类模型
    上一篇:Python数据类型最全知识总结
    下一篇:如何在Python项目中引入日志
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    Tensorflow与RNN、双向LSTM等的踩坑记录及解决 Tensorflow,与,RNN,双向,LSTM,