将生成器拆分为多个块,无需预先遍历它

新手上路,请多包涵

(这个问题与 这个这个 有关,但那些是预走发电机,这正是我想要避免的)

我想将生成器分成几块。要求是:

  • 不要填充块:如果剩余元素的数量小于块大小,则最后一个块必须更小。
  • 不要事先遍历生成器:计算元素是昂贵的,它只能由消费函数完成,而不是由分块器完成
  • 这当然意味着:不要在内存中累积(没有列表)

我尝试了以下代码:

 def head(iterable, max=10):
    for cnt, el in enumerate(iterable):
        yield el
        if cnt >= max:
            break

def chunks(iterable, size=10):
    i = iter(iterable)
    while True:
        yield head(i, size)

# Sample generator: the real data is much more complex, and expensive to compute
els = xrange(7)

for n, chunk in enumerate(chunks(els, 3)):
    for el in chunk:
        print 'Chunk %3d, value %d' % (n, el)

这以某种方式起作用:

 Chunk   0, value 0
Chunk   0, value 1
Chunk   0, value 2
Chunk   1, value 3
Chunk   1, value 4
Chunk   1, value 5
Chunk   2, value 6
^CTraceback (most recent call last):
  File "xxxx.py", line 15, in <module>
    for el in chunk:
  File "xxxx.py", line 2, in head
    for cnt, el in enumerate(iterable):
KeyboardInterrupt

Buuuut … 它永远不会停止(我必须按 ^C )因为 while True 。我想在发电机被消耗时停止该循环,但我不知道如何检测这种情况。我试过引发异常:

 class NoMoreData(Exception):
    pass

def head(iterable, max=10):
    for cnt, el in enumerate(iterable):
        yield el
        if cnt >= max:
            break
    if cnt == 0 : raise NoMoreData()

def chunks(iterable, size=10):
    i = iter(iterable)
    while True:
        try:
            yield head(i, size)
        except NoMoreData:
            break

# Sample generator: the real data is much more complex, and expensive to compute
els = xrange(7)

for n, chunk in enumerate(chunks(els, 2)):
    for el in chunk:
        print 'Chunk %3d, value %d' % (n, el)

但是异常只会在消费者的上下文中引发,这不是我想要的(我想保持消费者代码干净)

 Chunk   0, value 0
Chunk   0, value 1
Chunk   0, value 2
Chunk   1, value 3
Chunk   1, value 4
Chunk   1, value 5
Chunk   2, value 6
Traceback (most recent call last):
  File "xxxx.py", line 22, in <module>
    for el in chunk:
  File "xxxx.py", line 9, in head
    if cnt == 0 : raise NoMoreData
__main__.NoMoreData()

我怎样才能检测到生成器在 chunks 函数中已耗尽,而无需运行它?

原文由 blueFast 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 657
2 个回答

一种方法是查看第一个元素(如果有),然后创建并返回实际的生成器。

 def head(iterable, max=10):
    first = next(iterable)      # raise exception when depleted
    def head_inner():
        yield first             # yield the extracted first element
        for cnt, el in enumerate(iterable):
            yield el
            if cnt + 1 >= max:  # cnt + 1 to include first
                break
    return head_inner()

只需在您的 chunk 生成器中使用它并捕获 StopIteration 异常,就像您对自定义异常所做的那样。


更新: 这是另一个版本,使用 itertools.islice 替换大部分 head 函数和 for 循环。这个简单的 for 循环实际上与原始代码中笨拙的 while-try-next-except-break 构造 _完全相同_,因此结果 更具 可读性。

 def chunks(iterable, size=10):
    iterator = iter(iterable)
    for first in iterator:    # stops when iterator is depleted
        def chunk():          # construct generator for next chunk
            yield first       # yield element from for loop
            for more in islice(iterator, size - 1):
                yield more    # yield more elements from the iterator
        yield chunk()         # in outer generator, yield next chunk

我们可以得到比这更短的,使用 itertools.chain 来替换内部生成器:

 def chunks(iterable, size=10):
    iterator = iter(iterable)
    for first in iterator:
        yield chain([first], islice(iterator, size - 1))

原文由 tobias_k 发布,翻译遵循 CC BY-SA 3.0 许可协议

另一种创建组/块而不是 预走 生成器的方法是在使用 itertools.groupby itertools.count 。由于 count 对象独立于 iterable ,因此可以在不知道 iterable 包含什么的情况下轻松生成块。

groupby 的每次迭代调用 next 对象的 count 对象的方法并生成一个组/块 _键_(然后在块中进行整数除法)当前计数值乘以块的大小。

 from itertools import groupby, count

def chunks(iterable, size=10):
    c = count()
    for _, g in groupby(iterable, lambda _: next(c)//size):
        yield g

生成器函数 产生 的每个组/块 g 是一个迭代器。但是,由于 groupby 对所有组使用共享迭代器,组迭代器不能存储在列表或任何容器中,每个组迭代器应该在下一个之前使用。

原文由 Moses Koledoye 发布,翻译遵循 CC BY-SA 3.0 许可协议

推荐问题