• 企业400电话
  • 微网小程序
  • AI电话机器人
  • 电商代运营
  • 全 部 栏 目

    企业400电话 网络优化推广 AI电话机器人 呼叫中心 网站建设 商标✡知产 微网小程序 电商运营 彩铃•短信 增值拓展业务
    基于keras中训练数据的几种方式对比(fit和fit_generator)

    一、train_on_batch

    model.train_on_batch(batchX, batchY)

    train_on_batch函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型,大部分情况下我们不需要这么精细,99%情况下使用fit_generator训练方式即可,下面会介绍。

    二、fit

    model.fit(x_train, y_train, batch_size=32, epochs=10)
    

    fit的方式是一次把训练数据全部加载到内存中,然后每次批处理batch_size个数据来更新模型参数,epochs就不用多介绍了。这种训练方式只适合训练数据量比较小的情况下使用。

    三、fit_generator

    利用Python的生成器,逐个生成数据的batch并进行训练,不占用大量内存,同时生成器与模型将并行执行以提高效率。例如,该函数允许我们在CPU上进行实时的数据提升,同时在GPU上进行模型训练

    接口如下:

    fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_q_size=10, workers=1, pickle_safe=False, initial_epoch=0)

    generator:生成器函数

    steps_per_epoch:整数,当生成器返回steps_per_epoch次数据时,计一个epoch结束,执行下一个epoch。也就是一个epoch下执行多少次batch_size。

    epochs:整数,控制数据迭代的轮数,到了就结束训练。

    callbacks=None, list,list中的元素为keras.callbacks.Callback对象,在训练过程中会调用list中的回调函数

    举例:

    def generate_arrays_from_file(path):
                while True:
                    with open(path) as f:
                        for line in f:
                            # create numpy arrays of input data
                            # and labels, from each line in the file
                            x1, x2, y = process_line(line)
                            yield ({'input_1': x1, 'input_2': x2}, {'output': y})
     
    model.fit_generator(generate_arrays_from_file('./my_folder'),
                                steps_per_epoch=10000, epochs=10)

    补充:keras.fit_generator()属性及取值

    如下所示:

    fit_generator(self, generator, 
                        steps_per_epoch=None, 
                        epochs=1, 
                        verbose=1, 
                        callbacks=None, 
                        validation_data=None, 
                        validation_steps=None,  
                        class_weight=None,
                        max_queue_size=10,   
                        workers=1, 
                        use_multiprocessing=False, 
                        shuffle=True, 
                        initial_epoch=0)

    通过Python generator产生一批批的数据用于训练模型。generator可以和模型并行运行,例如,可以使用CPU生成批数据同时在GPU上训练模型。

    参数:

    generator:一个generator或Sequence实例,为了避免在使用multiprocessing时直接复制数据。

    steps_per_epoch:从generator产生的步骤的总数(样本批次总数)。通常情况下,应该等于数据集的样本数量除以批量的大小。

    epochs:整数,在数据集上迭代的总数。

    works:在使用基于进程的线程时,最多需要启动的进程数量。

    use_multiprocessing:布尔值。当为True时,使用基于基于过程的线程。

    例如:

    datagen = ImageDataGenator(...)
    model.fit_generator(datagen.flow(x_train, y_train,
                                     batch_size=batch_size),
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        workers=4)

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    您可能感兴趣的文章:
    • keras修改backend的简单方法
    • keras的get_value运行越来越慢的解决方案
    • 浅谈Keras中fit()和fit_generator()的区别及其参数的坑
    • Keras保存模型并载入模型继续训练的实现
    • TensorFlow2.0使用keras训练模型的实现
    • tensorflow2.0教程之Keras快速入门
    • 浅析关于Keras的安装(pycharm)和初步理解
    • 基于Keras的扩展性使用
    上一篇:python一秒搭建FTP服务器
    下一篇:python实现某考试系统生成word试卷
  • 相关文章
  • 

    © 2016-2020 巨人网络通讯 版权所有

    《增值电信业务经营许可证》 苏ICP备15040257号-8

    基于keras中训练数据的几种方式对比(fit和fit_generator) 基于,keras,中,训练,数据,