为 Keras model.fit_generator 使用生成器

新手上路,请多包涵

我最初在编写用于训练 Keras 模型的自定义生成器时尝试使用 generator 语法。所以我 yield__next__ --- 编辑。但是,当我尝试使用 model.fit_generator 训练我的模式时,我会收到一条错误消息,提示我的生成器不是迭代器。解决方法是将 yield 更改为 return 这也需要重新调整 __next__ 的逻辑以跟踪状态。与让 yield 为我完成工作相比,这非常麻烦。

有什么方法可以让我使用 yield 吗?如果我必须使用 return 语句,我将需要编写更多的迭代器,这些迭代器必须具有非常笨拙的逻辑。

原文由 doogFromMT 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 893
1 个回答

我无法帮助调试你的代码,因为你没有发布它,但我缩写了我为语义分割项目编写的自定义数据生成器,供你用作模板:

 def generate_data(directory, batch_size):
    """Replaces Keras' native ImageDataGenerator."""
    i = 0
    file_list = os.listdir(directory)
    while True:
        image_batch = []
        for b in range(batch_size):
            if i == len(file_list):
                i = 0
                random.shuffle(file_list)
            sample = file_list[i]
            i += 1
            image = cv2.resize(cv2.imread(sample[0]), INPUT_SHAPE)
            image_batch.append((image.astype(float) - 128) / 128)

        yield np.array(image_batch)

用法:

 model.fit_generator(
    generate_data('~/my_data', batch_size),
    steps_per_epoch=len(os.listdir('~/my_data')) // batch_size)

原文由 Jessica Alan 发布,翻译遵循 CC BY-SA 3.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进