This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.

Copyright (C) 2014-2023 Scientific IT Services of ETH Zurich,
Contributing Authors: Uwe Schmitt, Mikolaj Rybniski

12. Generator expressions and yield

When we replace the square brackets of a list comprehension with round brackets, we get a so-called generator expression:

li = (i * i for i in range(99999) if i % 7 == 0)
print(li)
<generator object <genexpr> at 0x107502260>

Such a generator expression can be considered a "lazy" list comprehension. This means that no list is created in memory. Instead, elements are computed on demand when we iterate over the generator:

for i, element in enumerate(li):
    print(element)
    if i == 3:
        break
0
49
196
441

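The generator remembers its position, so a second loop continues where the first one stopped. Note that this loop has no break and therefore consumes all remaining elements, although it prints only the first four: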

for i, element in enumerate(li):
    if i <= 3:
        print(element)
   
784
1225
1764
2401

After one complete pass the generator is "exhausted" and yields no further elements:

for element in li:
    print(element)
print("done")
done
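
Because an exhausted generator cannot be rewound, one possible approach is to wrap the generator expression in a small function that builds a fresh generator on every call. The following is a minimal sketch of that idea; the helper name squares_of_multiples_of_7 is made up for this illustration and is not part of the original example.

```python
def squares_of_multiples_of_7(limit):
    # Hypothetical helper: returns a *new* generator on every call.
    return (i * i for i in range(limit) if i % 7 == 0)


gen = squares_of_multiples_of_7(30)
first_pass = list(gen)   # consumes the generator
second_pass = list(gen)  # nothing left: the same object is exhausted
print(first_pass)        # [0, 49, 196, 441, 784]
print(second_pass)       # []

# Calling the helper again gives a fresh generator.
fresh = list(squares_of_multiples_of_7(30))
print(fresh == first_pass)  # True
```

If the values fit into memory and are needed several times, storing them once with list(...) is usually the simpler choice.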

Because elements are produced one at a time, such iterators can save a lot of memory compared to lists. They can also be chained to declare a data processing pipeline in a clear way and with little memory overhead:

even_numbers = (i for i in range(1_000_000) if i % 2 == 0)
squared = (i ** 2 for i in even_numbers)

first_ten = (next(squared) for _ in range(10))  # next(.) returns next element in iteration
print(first_ten)
<generator object <genexpr> at 0x107502810>
print(list(first_ten))
[0, 4, 16, 36, 64, 100, 144, 196, 256, 324]
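
To get a rough feeling for the memory savings, the sketch below compares a list comprehension with the equivalent generator expression using sys.getsizeof. This is only an approximation: sys.getsizeof measures the container object itself and not the integers it refers to, and the exact numbers depend on the Python version and platform, so the example only compares the two sizes.

```python
import sys

n = 1_000_000
as_list = [i for i in range(n) if i % 2 == 0]
as_generator = (i for i in range(n) if i % 2 == 0)

# sys.getsizeof reports only the container itself, not the referenced integers.
list_bytes = sys.getsizeof(as_list)
generator_bytes = sys.getsizeof(as_generator)
print(list_bytes > generator_bytes)  # True

# Aggregations such as sum consume the generator one element at a time.
total = sum(as_generator)
print(total == sum(as_list))  # True
```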

The keyword yield in the following example turns numbers into a generator function instead of an ordinary function:

def numbers():
    print("a")
    yield 1
    print("b")
    yield 2
    print("c")

Calling numbers() does not work like a regular function call. It returns a generator object instead:

print(numbers())
<generator object numbers at 0x107502ab0>

As you can see, calling numbers() does not execute the function body but returns a generator object, similar to what we have seen above.
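
The difference between the function and the object it returns can be checked with the inspect module from the standard library. The sketch below uses a shortened generator function named small_numbers, which is a name chosen only for this illustration.

```python
import inspect


def small_numbers():
    # Hypothetical stand-in for the generator function above, without prints.
    yield 1
    yield 2


gen = small_numbers()
print(inspect.isgeneratorfunction(small_numbers))  # True
print(inspect.isgenerator(gen))                    # True
print(inspect.isgenerator(small_numbers))          # False

# A generator object is its own iterator.
print(iter(gen) is gen)  # True
```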

Such a generator can be used like an iterator:

for x in numbers():
    print("x is", x)
a
x is 1
b
x is 2
c
iterator = numbers()
print(next(iterator))
a
1

So calling numbers() does not run any statement of the function body. The first iteration step prints a and yields 1. The second step prints b and yields 2. The third step prints c, and since the function body ends there, iteration stops with a StopIteration exception:

print(next(iterator))
b
2
print(next(iterator))
c
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
Cell In[13], line 1
----> 1 print(next(iterator))

StopIteration: 
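
A for loop handles this StopIteration exception for us. Roughly speaking, it calls next repeatedly and leaves the loop once the exception is raised. The following sketch imitates this by hand; the names demo_numbers, it and collected are chosen only for this illustration.

```python
def demo_numbers():
    # Hypothetical generator function yielding two values.
    yield 1
    yield 2


it = demo_numbers()
collected = []
while True:
    try:
        value = next(it)
    except StopIteration:
        # The generator is exhausted: stop looping, as a for loop would.
        break
    collected.append(value)
print(collected)  # [1, 2]

# next accepts an optional default which is returned instead of raising.
print(next(it, "no more values"))  # no more values
```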
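Generators defined with yield can be chained into a pipeline in the same way as generator expressions. Here numbers() is redefined as an infinite generator:

def numbers():
    # infinite generator: yields 0, 1, 2, ... and never stops on its own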
    i = 0
    while True:
        yield i
        i += 1


def blocks_of_size(n, iterator):
    while True:
        block = [next(iterator) for _ in range(n)]
        yield block


def average(block_iterator):
    for block in block_iterator:
        yield sum(block) / len(block)


pipeline = average(blocks_of_size(5, numbers()))
for _ in range(5):
    print(next(pipeline))
2.0
7.0
12.0
17.0
22.0
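
The pipeline above relies on numbers() never ending. With a finite input, next(iterator) inside blocks_of_size would eventually raise StopIteration, and Python 3.7 and later turn a StopIteration that escapes a generator body into a RuntimeError. One possible variant for finite inputs is sketched below. The name blocks_of_size_finite is made up for this illustration, and the choice to emit a shorter last block is an assumption, not something the original example specifies.

```python
from itertools import islice


def blocks_of_size_finite(n, iterable):
    # Hypothetical variant: yields lists of up to n items and stops cleanly
    # when the input is exhausted. The last block may be shorter than n.
    iterator = iter(iterable)
    while True:
        block = list(islice(iterator, n))
        if not block:
            return
        yield block


blocks = list(blocks_of_size_finite(2, range(5)))
print(blocks)  # [[0, 1], [2, 3], [4]]

averages = [sum(block) / len(block) for block in blocks]
print(averages)  # [0.5, 2.5, 4.0]
```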

The standard library contains the module itertools, which offers a rich set of predefined iterators and tools for working with iterators.

import itertools

# infinite counting iterator:
c = itertools.count()
print(next(c))
print(next(c))
print(next(c))
0
1
2
c = itertools.cycle(range(3))
print([next(c) for _ in range(10)])
[0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
c = itertools.product(range(2), range(3))
for item in c:
    print(item)
(0, 0)
(0, 1)
(0, 2)
(1, 0)
(1, 1)
(1, 2)
c = itertools.zip_longest(range(4), range(3), "ab")

for item in c:
    print(item)
(0, 0, 'a')
(1, 1, 'b')
(2, 2, None)
(3, None, None)
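
Infinite iterators such as itertools.count() are often combined with a tool that cuts them off again. As a short sketch, itertools.islice takes a fixed number of elements and itertools.takewhile takes elements as long as a condition holds:

```python
import itertools

# Take the first five squares from an infinite stream of squares.
first_five_squares = list(
    itertools.islice((i * i for i in itertools.count()), 5)
)
print(first_five_squares)  # [0, 1, 4, 9, 16]

# Take squares as long as they stay below 50.
small_squares = list(
    itertools.takewhile(lambda x: x < 50, (i * i for i in itertools.count()))
)
print(small_squares)  # [0, 1, 4, 9, 16, 25, 36, 49]
```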

See also http://pymotw.com/3/itertools/index.html

Exercise section*

Repeat the examples and play with them.

Optional exercise

Write a function chain which takes an arbitrary number of iterables and iterates over all of them, one after the other. For example,

for value in chain([1, 2, 3], (i**2 for i in range(3)), (7, 8)):
    print(value, end=" ")

prints

1 2 3 0 1 4 7 8