• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    keras的get_value运行越来越慢的解决方案

    keras 深度学习框架中get_value函数运行越来越慢,内存消耗越来越大问题

    问题描述

    如上图所示,经过时间和内存消耗跟踪测试,发现是keras.backend.get_value() 函数导致的程序越来越慢,而且严重的造成内存泄露;

    查看该函数内部实现,发现一个主要核心是x.eval(session=get_session()),该语句可能是导致内存泄露和运行慢的核心语句; 根据查看一些博文得到了运行得越来越慢的

    原因该x.eval函数会添加新的节点到tf的图中;而这也导致了tf的图越来越大,内存泄露;

    解决方法

    import tensorflow.keras.backend as K
    
    def get_my_session(gpu_fraction=0.1):
        '''Assume that you have 6GB of GPU memory and want to allocate ~2GB'''
    
        num_threads = os.environ.get('OMP_NUM_THREADS')
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
    
        if num_threads:
            return tf.Session(config=tf.ConfigProto(
                gpu_options=gpu_options, intra_op_parallelism_threads=num_threads))
        else:
            return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    
    K.set_session(get_my_session())

    如上图所示, 我在使用tensorflow之前(也就是该工程文件前面),对session进行自定义,然后用自定义的session设定keras.backend.set_session();

    然后删除get_value() 函数,直接用get_value()中所使用的执行语句x.eval(session=get_my_session());这样这个添加节点导致内存泄露的核心语句x.eval()就使用的是该工程统一自定义session,然后用tf.reset_default_graph() 对图重置就可以了

    即上图问题代码修改为:

    output = ctc_decode(y_pred,input_length=input_length,)
    output = output[0][0]
    out = output.eval(session=get_my_session())
    # 删除 K.get_value(out[0][0])
    tf.reset_default_graph() # 然后重置tf图,这句很关键

    这样就解决了get_value()导致的越来越慢的问题;

    个人认为:这样可能就不会总是添加新的节点,导致tf图不断地无限变大;而是重复使用这一个自定义的节点。

    补充:tensorflow与keras之间版本问题引起get_session问题解决办法

    1.产生报错原因

    import tensorflow.keras.backend as K
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults) # set up default values
        self.__dict__.update(kwargs) # and update with user overrides
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()
        self.sess = K.get_session()

    报错如下:

    get_session is not available when using TensorFlow 2.0.

    意思是 tf2.0 没有 get_session

    2.解决方案1

    import tensorflow.python.keras.backend as K
    sess = K.get_session()

    3. 解决方案2

    import tensorflow as tf
    sess = tf.compat.v1.keras.backend.get_session()

    之前一直采用方案1 解决,感觉比较方便;但是解决方案1 有其它属性会丢失问题

    比如AttributeError: module ‘keras.backend' has no attribute image_dim_ordering

    所以建议大家采用方案2

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • keras修改backend的简单方法
    • 基于keras中训练数据的几种方式对比(fit和fit_generator)
    • 浅谈Keras中fit()和fit_generator()的区别及其参数的坑
    • Keras保存模型并载入模型继续训练的实现
    • TensorFlow2.0使用keras训练模型的实现
    • tensorflow2.0教程之Keras快速入门
    • 浅析关于Keras的安装(pycharm)和初步理解
    • 基于Keras的扩展性使用
    上一篇:windowns使用PySpark环境配置和基本操作
    下一篇:pyspark创建DataFrame的几种方法
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    keras的get_value运行越来越慢的解决方案 keras,的,get,value,运行,越来,