第一章 培养Pythonic思维

第1条 查询自己使用的Python版本

1
$ python3 --version

不要使用Python2!

推荐3.8左右的版本(写于2023-04),初学者不要用太新的版本,很多库不支持最新版。

第2条 遵循PEP8风格指南

PEP8全称Python Enhancement Proposal #8,是针对Python代码格式编订的风格指南。不符合PEP8的做法一般编译器(如Pycharm)会自动检查标黄提醒。以下是一些重点规范:

  • 用4个空格表示缩进,而不是制表符
  • 下划线变量命名。
  • 表达式一行写不下,可以用括号括起来,不要用 \ 续行。
  • import语句放在开头,总是使用绝对名称,按顺序划为三部分:标准库、第三方、自己的模块。

第3条 了解bytes与str区别

bytes包含的是原始数据,即8位无符号值

1
a = b'h\x65llo'

str包含的是Unicode码点

1
a = 'a\u0300 propos'

第4条 用支持插值的f-string取代C风格的格式字符串与str.format方法

1
2
3
4
5
6
a = 'tom'
b = 'jerry'

print('%s and %s' % (a, b)) # C风格的格式字符串
print('{0} and {1}'.format(a, b)) # str.format方法
print(f'{a} and {b}') # 支持插值的f-string

第5条 用辅助函数取代复杂的表达式

即使这个函数只用两三次,也是值得的。

第6条 把数据结构直接拆分到多个变量里,不要专门通过下标访问

1
2
items = ('a', 'b')
first, second = items

第7条 尽量用enumerate取代range

enumerate与zip都生成lazy generator。enumerate可以同时迭代出下标与数据。

第8条 用zip函数同时遍历两个生成器

注意:如果提供的迭代器长度不一致,只要其中一个迭代器迭代完毕,zip就会停止。

如果想按最长的那个迭代器遍历,应该改用itertools.zip_longest函数。

第9条 不要在for与while循环后面写else块

for与while后的else(如果循环没有从头到尾执行完,就不会执行else块里的内容)与其他else逻辑(如果没执行前面的语句,那就执行else块)不同。容易造成混淆。

第10条 用赋值表达式减少重复代码

海象运算符(:=),3.8版本引入的语法。

第二章 列表与字典

第11条 学会对序列做切片

类似a[1:5:2],1为起始下标(包含),5为结束下标(不包含),2为步进。

第12条 不要在切片里同时指定起止下标和步进

如果必须同时用这三个,那就分成两次做。而且应该把最能缩减列表长度的操作放在前面。

第13条 通过带星号的unpacking操作来捕获多个元素,不要用切片

1
2
car_ages = [9, 8, 5, 2, 1, 0]
oldest, *others, youngest = car_ages

这种带*的表达式可以出现在任意位置,但必须得有个值与其匹配,并且同一级中只能有一个。

第14条 用sort方法的key参数来表示复杂的排序逻辑

1
tools.sort(key=lambda x: x.weight)

sort函数有reverse参数,可以将默认的升序变成降序。

但reverse会同时改变所有指标的排序方式,对支持一元减操作符的类型可以取反,不支持的话可以多次排序。sort是稳定的排序算法。

第15条 不要过分依赖给字典添加条目时所用的顺序

在Python3.5以及以前的版本中,字典不保证迭代顺序与插入顺序一致。但3.6版本以后会保留键值对添加时所用的顺序。

第16条 用get处理键不在字典中的情况,不要使用in与KeyError

1
count = counters.get(key, 0)

get函数第一个参数指定想查的键,第二个参数指定这个键不在时返回的值。

第17条 用defaultdict处理内部状态中缺失的元素,而不要用setdefault

1
2
3
4
5
6
7
8
from collections import defaultdict

class Visits:
def __init__(self):
self.data = defaultdict(set)

def add(self, country, city):
self.data[country].add(city)

第18条 学会利用__missing__构造依赖键的默认值

传给defaultdict的函数必须是不需要参数的函数,无法创造出需要依赖键名的默认值。

可以定义一个自己的dict子类并实现__missing__方法。

1
2
3
4
5
class Pictures(dict):
def __missing__(self, key):
value = open_picture(key)
self[key] = value
return value

第三章 函数

第19条 不要把函数返回的多个数值拆分到三个以上的变量中

函数返回的其实是个元组。

返回值应该通过小类或namedtuple实例完成。

第20条 遇到意外状况时应该抛出异常,不要返回None

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def careful_divide(a, b):
try:
return a / b
except ZeroDivisionError:
return None

x, y = 1, 0
result = careful_divide(x, y)
if result is None:
print('Invalid inputs')

# 但是,调用者可能会这么用。那么,当x为0时,结果为0,也会被判定为无效输入
if not result:
print('Invalid inputs')

# 可以改为下面这样
def careful_divide(a, b):
try:
return a / b
except ZeroDivisionError:
raise ValueError('Invalid inputs')

第21条 了解如何在闭包里面使用外围作用域中的变量

引用变量时,会按照以下顺序在各个作用域里查找这个变量:

  • 当前函数的作用域
  • 外围作用域
  • 包含当前代码的那个模块所对应的作用域(全局作用域)
  • 内置作用域

但在变量赋值时,如果当前作用域不存在这个变量,那么即使外围作用域里有同名的变量,Python也还是会把这次的赋值当做变量的定义处理。这样可以防止函数中的局部变量污染外围模块。

先用nonlocal声明再赋值就可以修改外围作用域中的变量。

第22条 用数量可变的位置参数给函数设计清晰的参数列表

这些位置参数通常简称varargs或star args。例如:

1
2
3
4
5
6
def log(message, *values):
if not values:
print(message)
else:
values_str = ', '.join(str(x) for x in values)
print(f'{message}: {values_str}')

注意: * 操作符在生成器前,程序可能因为内存耗尽崩溃。

第23条 用关键字参数来表示可选的行为

传参除了按位置,还可以按关键字传递。

1
2
3
4
5
6
7
def remainder(number, divisor):
return number % divisor

remainder(20, 7)
remainder(20, divisor=7)
remainder(number=20, divisor=7)
remainder(divisor=7, number=20)

以上四种写法效果相同。

把 ** 运算符加在字典前面,会把字典里面的键值对以关键字参数的形式传给函数。

关键字参数有三个好处:

  • 让初次阅读的人更容易看懂
  • 可以带有默认值
  • 可以很灵活地扩展函数,不用担心会影响原有的函数调用代码,有助于维护向后兼容

第24条 用None和docstring来描述默认值会变的参数

1
2
3
4
5
6
7
import json

def decode(data, default={}):
try:
return json.loads(data)
except ValueError:
return default

上面的写法系统只会计算一次default参数(在加载这个模块的时候),所以每次调用的时候,给调用者返回的都是同一个 {} ,程序会出现很奇怪的效果:

1
2
3
4
5
6
7
8
9
10
foo = decode('bad data')
foo['stuff'] = 5
bar = decode('also bad')
bar['meep'] = 1
print('Foo: ', foo)
print('Bar: ', bar)

>>>
Foo: {'stuff': 5, 'meep': 1}
Bar: {'stuff': 5, 'meep': 1}

解决这个问题,可以把默认值设成None,并在docstring中说明。

第25条 用只能以关键字指定和只能按位置传入的参数来设计清晰的参数列表

1
2
3
4
def safe_division(numerator, denominator, /,
ndigits=10, *,
ignore_overflow=False,
ignore_zero_division=False)
  • Keyword-only argument是一种只能通过关键字指定而不能通过位置指定的参数。这迫使调用者必须指明这个值是传递给哪一个参数的。这些参数位于 * 符号的右侧。
  • Positional-only argument是一种只能通过位置指定而不能通过关键字指定的参数。这可以降低调用代码与参数名称之间的耦合度。这些参数位于 / 符号的左侧。
  • 位于 / 与 * 之间的参数,可以按位置也可以按关键字指定。

第26条 用functools.wraps定义函数修饰器

修饰器(decorator):

1
2
3
4
5
6
7
def trace(func):
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
print(f'{func.__name__}({args!r}, {kwargs!r}) '
f'-> {result!r}')
return result
return wrapper

但上面这样写有个副作用:修饰器返回的那个值,它的名字不叫原来的名字。

可能会干涉那些需要利用introspection(反射)机制运作的工具,如调试器,help,对象序列化器。

解决这个问题可以用functools.wraps,它会将重要的元数据(metadata)全都从内部函数复制到外部函数:

1
2
3
4
5
6
7
from functools import wraps

def trace(func):
@wraps(func)
def wrapper(*args, **kwargs):
...
return wrapper

第四章 推导与生成

第27条 用列表推导取代map与filter

1
2
3
4
5
6
7
8
9
10
11
a = [1, 2, 3, 4, 5, 6]

# 列表推导
squares = [x**2 for x in a]
even_squres = [x**2 for x in a if x % 2 == 0]

# 字典推导
even_squres_dict = {x: x**2 for x in a if x % 2 == 0}

# 集合推导
threes_cubed_set = {x**3 for x in a if x % 3 == 0}

第28条 控制推导逻辑的子表达式不要超过两个

列表推导支持多个子表达式,如:

1
2
3
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flat = [x for row in matrix for x in row]
squred = [[x**2 for x in row] for row in matrix]

也可以连用两个 if ,如果在同一层循环内,默认为and关系。

第29条 用赋值表达式消除推导中的重复代码

1
2
3
4
5
6
7
has_bug = {name: get_batches(stock.get(name, 0), 4)
for name in order
if get_batches(stock.get(name, 0), 8)}

# 上面这段代码可以改成:
has_bug = {name: batches for name in order
if (batches := get_batches(stock.get(name, 0), 8))}

第30条 不要让函数直接返回列表,应该让它逐个生成列表里的值

案例问题:返回字符串每个单词首字母所对应的下标

1
2
3
4
5
6
7
8
def index_words(text):
result = []
if text:
result.append(0)
for index, letter in enumerate(text):
if letter == ' ':
result.append(index + 1)
return result

但上述代码有两个缺点:

  • 看起来杂乱
  • 必须把所有结果都保存在列表中。如果输入数据特别多,可能耗尽内存。

这种函数改为用生成器(generator)实现比较好。

1
2
3
4
5
6
7
def index_words(text):
result = []
if text:
yield 0
for index, letter in enumerate(text):
if letter == ' ':
yield index + 1

生成器函数不用把整个输入值全部读进来,也不用一次就把所有的输出值都算好。

第31条 谨慎地迭代函数所收到的参数

迭代器只能产生一次结果:

1
2
3
4
5
6
7
it = ...
print(list(it))
print(list(it))

>>>
[1, 2, 3]
[]

可以先迭代出结果,再保存。但这样如果输入数据特别多,可能耗尽内存。

可以自定义一个可迭代的容器类:

1
2
3
4
5
6
7
8
class ReadVisits:
def __init__(self, data_path):
self.data_path = data_path

def __iter__(self):
with open(self.data_path) as f:
for line in f:
yield int(line)

第32条 考虑用生成器表达式改写数据量较大的列表推导

1
2
3
4
5
6
7
8
9
10
11
12
13
value = [len(x) for x in open('my_file.txt')]

>>>
[100, 57, 15, 1, 12, 75]

# 生成器表达式
it = (len(x) for x in open('my_file.txt'))
>>>
<generator object <genexpr> at 0x108993dd0>

print(next(it))
>>>
100

生成器表达式之间可以继续迭代组合使用,使用的内存同样不会太多。

第33条 通过yield from把多个生成器连起来用

案例问题:编写程序让图片先快速移动一段时间,再暂停,再慢速移动一段时间:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def move(period, speed):
for _ in range(period):
yield speed

def pause(delay):
for _ in range(delay):
yield 0

def animate():
for delta in move(4, 5.0):
yield delta
for delta in pause(3):
yield delta
for delta in move(2, 3.0):
yield delta

但上述写法animate函数里存在很多重复。可以改写为下面的代码,看上去更清晰,并且速度更快:

1
2
3
4
def animate_composed():
yield from move(4, 5.0)
yield from pause(3)
yield from move(2, 3.0)

第34条 不要用send给生成器注入数据

Python的生成器支持send方法,这可以让生成器变成双向通道。send方法可以把参数发给生成器,让它成为上一条yield表达式的求值结果,并将生成器推进到下一条yield表达式,然后把yield表达式右边的值返回给send方法的调用者。然而在一般情况下,我们还是会通过内置的next函数推进生成器,按照这种写法,上一条yield表达式的求值结果总是None。

1
2
3
4
5
6
7
8
9
10
11
def my_generator():
received = yield 1
print(f'received = {received}')

it = iter(my_generator())
output = next(it)
print(f'output = {output}')

>>>
output = 1
received = None

因为“send方法可以把参数发给生成器,让它成为上一条yield表达式的求值结果”,所以首次调用send方法时,只能传None。

send与yield from搭配起来使用可能会导致奇怪的结果。通过迭代器向组合起来的生成器输入数据要比send方法方案好。

第35条 不要通过throw变换生成器的状态

生成器可以把调用者通过throw方法传进来的Exception实例重新抛出。如果调用了这个方法,生成器下次推进时就不会像平常那样,直接走到下一条yield,而是会把传入的异常重新抛出。也可以用try / except复合语句把yield包裹起来,异常捕获后不继续抛出异常,那么生成器函数会推进到下一条yield表达式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def my_generator():
yield 1

try:
yield 2
except MyError:
print('Got MyError!')
else:
yield 3

yield 4

it = my_generator()
print(next(it))
print(next(it))
print(it.throw(MyError('test error')))

>>>
1
2
Got MyError!
4

但这种写法比较难懂。凡是想用生成器与异常实现的功能,通常都可以改用异步机制去做。或通过类的__iter__方法实现生成器,并且专门提供一个方法,让调用者通过这个方法来触发这种特殊的状态变换逻辑。

第36条 考虑用itertools拼装迭代器与生成器

Python内置的itertools模块里有很多函数,可以用来安排迭代器之间的交互关系。实现比较难写的迭代逻辑之前,应该先查看它的文档。下面分三大类列出其中最重要的函数。

连接多个迭代器

chain

chain可以把多个迭代器从头到尾连成一个迭代器。

1
2
3
4
5
it = itertools.chain([1, 2, 3], [4, 5, 6])
print(list(it))

>>>
[1, 2, 3, 4, 5, 6]

repeat

制作一个不停输出某个值的迭代器。

1
2
3
4
5
it = itertools.repeat('hello', 3)
print(list(it))

>>>
['hello', 'hello', 'hello']

cycle

循环输出某段内容之中的各项元素。

1
2
3
4
5
6
it = itertools.cycle([1, 2])
result = [next(it) for _ in range(10)]
print(result)

>>>
[1, 2, 1, 2, 1, 2, 1, 2, 1, 2]

tee

tee可以让一个迭代器分裂成多个平行的迭代器,具体个数由第二个参数指定。

1
2
3
4
5
6
7
8
9
it1, it2, it3 = itertools.tee(['first', 'second'], 3)
print(list(it1))
print(list(it2))
print(list(it3))

>>>
['first', 'second']
['first', 'second']
['first', 'second']

zip_longest

与内置的zip类似,但区别在于,如果源迭代器长度不同,它会用fillvalue参数的值来填补提前耗尽的那些迭代器所留下的空缺。

1
2
3
4
5
6
7
8
9
10
11
12
13
keys = ['one', 'two', 'three']
values = [1, 2]

normal = list(zip(keys, values))
print('zip: ', normal)

it = itertools.zip_longest(keys, values, fillvalue='nope')
longest = list(it)
print('zip_longest:', longest)

>>>
zip: [('one', 1), ('two', 2)]
zip_longest:[('one', 1), ('two', 2), ('three', 'nope')]

过滤源迭代器中的元素

islice

在不拷贝数据的前提下,按照下标切割源迭代器。可以只给出终点,也可以同时给出起点和终点,还可以指定步进值。

1
2
3
4
5
6
7
8
9
10
11
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

first_five = itertools.islice(values, 5)
print('First five: ', list(first_five))

middle_odds = itertools.islice(values, 2, 8, 2)
print('Middle odds:', list(middle_odds))

>>>
First five: [1, 2, 3, 4, 5]
Middle odds: [3, 5, 7]

takewhile

takewhile会一直从源迭代器获取元素,直到某元素让测试函数返回False为止。

1
2
3
4
5
6
7
8
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

less_than_seven = lambda x: x < 7
it = itertools.takewhile(less_than_seven, values)
print(list(it))

>>>
[1, 2, 3, 4, 5, 6]

dropwhile

dropwhile会一直跳过元素,直到某元素让测试函数返回False为止。

1
2
3
4
5
6
7
8
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

less_than_seven = lambda x: x < 7
it = itertools.dropwhile(less_than_seven, values)
print(list(it))

>>>
[7, 8, 9, 10]

filterfalse

与内置的filter相反,会逐个输出源迭代器里使得测试函数返回False的那些元素。

1
2
3
4
5
6
7
8
9
10
11
12
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
evens = lambda x: x % 2 == 0

filter_result = filter(evens, values)
print('Filter: ', list(filter_result))

filter_false_result = itertools.filterfalse(evens, values)
print('Filter false:', list(filter_false_result))

>>>
Filter: [2, 4, 6, 8, 10]
Filter false: [1, 3, 5, 7, 9]

用源迭代器中的元素合成新元素

accumulate

从源迭代器里取出一个元素,并把已经累计的结果与这个元素一起传给表示累加逻辑的函数,然后输出那个函数的计算结果,并把结果当成新的累计值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
sum_reduce = itertools.accumulate(values)
print('Sum: ', list(sum_reduce))

def sum_modulo_20(first, second):
output = first + second
return output % 20

modulo_reduce = itertools.accumulate(values, sum_modulo_20)
print('Modulo:', list(modulo_reduce))

>>>
Sum: [1, 3, 6, 10, 15, 21, 28, 36, 45, 55]
Modulo: [1, 3, 6, 10, 15, 1, 8, 16, 5, 15]

product

product会从一个或多个源迭代器里获取元素,并计算笛卡尔积,可以取代那种多层嵌套的列表推导代码。

1
2
3
4
5
6
7
8
9
single = itertools.product([1, 2], repeat=2)
print('Single: ', list(single))

multiple = itertools.product([1, 2], ['a', 'b'])
print('Multiple:', list(multiple))

>>>
Single: [(1, 1), (1, 2), (2, 1), (2, 2)]
Multiple: [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]

permutations

permutations会考虑源迭代器所能给出的全部元素,并逐个输出由其中N个元素形成的每种有序排列方式,元素相同但顺序不同,算作两种排列。

1
2
3
4
5
6
it = itertools.permutations([1, 2, 3, 4], 2)
print(list(it))

>>>
[(1, 2), (1, 3), (1, 4), (2, 1), (2, 3), (2, 4),
(3, 1), (3, 2), (3, 4), (4, 1), (4, 2), (4, 3)]

combinations

combinations会考虑源迭代器所能给出的全部元素,并逐个输出由其中N个元素形成的每种无序组合方式,元素相同但顺序不同,算作同一种组合。

1
2
3
4
5
it = itertools.combinations([1, 2, 3, 4], 2)
print(list(it))

>>>
[(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]

combinations_with_replacement

combinations_with_replacement与combinations类似,但它允许同一个元素在组合里多次出现。

1
2
3
4
5
6
it = itertools.combinations_with_replacement([1, 2, 3, 4], 2)
print(list(it))

>>>
[(1, 1), (1, 2), (1, 3), (1, 4), (2, 2),
(2, 3), (2, 4), (3, 3), (3, 4), (4, 4)]

第五章 类与接口

Python完全支持继承多态封装等各种机制。熟悉类与接口的用法,可以帮助我们写出易于维护的代码。

第37条 用组合起来的类来实现多层结构,不要用嵌套的内置类型

遇到比较复杂的需求,那么不要再嵌套字典、元组、集合等内置类型,而是应该编写一批新类并让这些类形成一套体系。

把多层嵌套的内置类型重构为类体系

使用具名元组(namedtuple)定义小型的类以表示不可变的数据:

1
2
3
from collections import namedtuple

Grade = namedtuple('Grade', ('score', 'weight'))

可通过位置参数构造,也可通过关键字参数构造。方便后期改写为普通的类。

namedtuple的局限:

  • 无法指定默认参数值
  • 属性值仍然可以通过数字下标与迭代访问

第38条 让简单的接口接受函数,而不是类的实例

许多内置的API允许传入某个函数来定制它的行为,这种函数可以叫做挂钩(hook)。API在执行过程中,会回调(call back)这些挂钩函数。例如:

1
2
3
4
5
6
names = ['Charles', 'Plato', 'Alice', 'Bob']
names.sort(key=len)
print(names)

>>>
['Bob', 'Plato', 'Alice', 'Charles']

其他语言中hook可能会用抽象类定义,但在Python中,许多hook都是无状态的函数(创建时不需要参数,也没有任何内存),带有明确的参数与返回值。

某个类定义了__call__特殊方法那么就是callable,能够像函数那样调用。

如果想用函数来维护状态,可以考虑定义一个带有__call__方法的新类,而不要用有状态的闭包实现。

1
2
3
4
5
6
7
8
9
10
11
class BetterCountMissing:
def __init__(self):
self.added = 0

def __call__(self):
self.added += 1
return 0

counter = BetterCountMissing()
assert counter() == 0
assert callable(counter)

第39条 通过@classmethod多态来构造同一体系中的各类对象

在Python中,不仅对象支持多态,类也支持多态。(这里的多态是指在超类上调用方法,实际触发的是子类的同名方法)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class GenericInputData:
def read(self):
raise NotImplementedError

@classmethod
def generate_inputs(cls, config):
raise NotImplementedError

class PathInputData(GenericInputData):
...

@classmethod
def generate_inputs(cls, config):
data_dir = config['data_dir']
for name in os.listdir(data_dir):
yield cls(os.path.join(data_dir, name))

Python只允许每个类有一个构造方法,也就是__init__方法。如果想在超类中用通用的代码构造子类实例,可以考虑@classmethod方法,并在里面用cls(…)的形式构造具体的子类对象。

第40条 通过super初始化超类

以前有种简单的写法,能在子类里面执行超类的初始化逻辑(直接在超类名称上调用__init__方法并把子类实例传进去)。

1
2
3
4
5
6
7
class MyBaseClass:
def __init__(self, value):
self.value = value

class MyChildClass(MyBaseClass):
def __init__(self):
MyBaseClass.__init__(self, 5)

但是容易出现问题:

  • 超类的构造逻辑不一定会按照它们在子类class语句中的声明顺序执行,而是依照__init__的调用顺序
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class MyBaseClass:
def __init__(self, value):
self.value = value

class TimesTwo:
def __init__(self):
self.value *= 2

class PlusFive:
def __init__(self):
self.value += 5

class AnotherWay(MyBaseClass, PlusFive, TimesTwo):
def __init__(self, value):
MyBaseClass.__init__(self, value)
TimesTwo.__init__(self)
PlusFive.__init__(self)

bar = AnotherWay(5)
print(bar.value)

>>>
15
  • 无法正确处理菱形继承
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class MyBaseClass:
def __init__(self, value):
self.value = value

class TimesSeven(MyBaseClass):
def __init__(self, value):
MyBaseClass.__init__(self, value)
self.value *= 7

class PlusNine(MyBaseClass):
def __init__(self, value):
MyBaseClass.__init__(self, value)
self.value += 9

class ThisWay(TimesSeven, PlusNine):
def __init__(self, value):
TimesSeven.__init__(self, value)
PlusNine.__init__(self, value)

foo = ThisWay(5)
print('Should be (5 * 7) + 9 = 44 but is', foo.value)

>>>
Should be (5 * 7) + 9 = 44 but is 14

当ThisWay调用第二个超类的__init__时,那个方法会再度触发MyBaseClass的__init__,导致self.value重新变成5。为了解决这些问题,Python内置了super函数并且规定了标准的方法解析顺序(method resolution order,MRO)。super能够确保菱形继承体系内共同超类只初始化一次。MRO可以确定超类之间的初始化顺序,它遵循C3线性化算法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class MyBaseClass:
def __init__(self, value):
self.value = value

class TimesSevenCorrect(MyBaseClass):
def __init__(self, value):
super().__init__(value)
self.value *= 7

class PlusNineCorrect(MyBaseClass):
def __init__(self, value):
super().__init__(value)
self.value += 9

class GoodWay(TimesSevenCorrect, PlusNineCorrect):
def __init__(self, value):
super().__init__(value)

foo = GoodWay(5)
print('Should be 7 * (5 + 9) = 98 but is', foo.value)

>>>
Should be 7 * (5 + 9) = 98 but is 98

超类之间的初始化顺序,要由子类的MRO确定,它可以通过mro方法查询:

1
2
3
4
5
6
7
8
9
mro_str = '\n'.join(repr(cls) for cls in GoodWay.mro())
print(mro_str)

>>>
<class '__main__.GoodWay'>
<class '__main__.TimesSevenCorrect'>
<class '__main__.PlusNineCorrect'>
<class '__main__.MyBaseClass'>
<class 'object'>

super函数也可以用双参数的形式调用。第一个参数表示从这个类型开始(不含该类型本身)按照方法解析顺序(MRO)向上搜索,而解析顺序则要由第二个参数所在类型的_mro_决定。例如,按照下面这种写法,如果在super所返回的内容上调用__init__.方法,那么程序会从ExplicitTrisect类型开始(不含该类型本身)按照MRO向上搜索,直至找到这样的__init__方法为止,而解析顺序是由第二个参数(self)所属的类型(ExplicitTrisect)决定的,所以解析顺序是ExplicitTrisect -> MyBaseClass -> object。

1
2
3
4
class ExplicitTrisect(MyBaseClass):
def __init__(self, value):
super(ExplicitTrisect, self).__init__(value)
self.value /= 3

__init__方法里面通过super初始化实例时,不需要采用双参数的形式,而是可以直接采用不带参数的写法调用super,这样Python编译器会自动将__class__和self当成参数传递进去。所以,下面这两种写法跟刚才那种写法是同一个意思。

1
2
3
4
5
6
7
8
9
class AutomaticTrisect(MyBaseClass):
def __init__(self, value):
super(__class__, self).__init__(value)
self.value /= 3

class ImplicitTrisect(MyBaseClass):
def __init__(self, value):
super().__init__(value)
self.value /= 3

第41条 考虑用mix-in类来表示可组合的功能

Python支持多重继承,但应该尽量少用。如果既要通过多重继承来方便地封装逻辑,又想避开可能出现的问题,应该把有待继承的类写成min-in类。这种类只提供一小套方法给子类去沿用,而不定义自己实例级别的属性,也不需要__init__构造函数。

案例问题:把内存中的Python对象表示成字典形式以便做序列化处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class ToDictMixin:
def to_dict(self):
return self._traverse_dict(self.__dict__)

def _traverse_dict(self, instance_dict):
output = {}
for key, value in instance_dict.items():
output[key] = self._traverse(key, value)
return output

def _traverse(self, key, value):
if isinstance(value, ToDictMixin):
return value.to_dict()
elif isinstance(value, dict):
return self._traverse_dict(value)
elif isinstance(value, list):
return [self._traverse(key, i) for i in value]
elif hasattr(value, '_dict_'):
return self._traverse_dict(value.__dict__)
else:
return value

下面以二叉树为例,演示如何使表示二叉树的BinaryTree类具备刚才那个mix-in所提供的功能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class BinaryTree(ToDictMixin):
def __init__(self, value, left=None, right=None):
self.value = value
self.left = left
self.right = right

tree = BinaryTree(10,
left=BinaryTree(7, right=BinaryTree(9)),
right=BinaryTree(13, left=BinaryTree(11)))
print(tree.to_dict())

>>>
{'value': 10,
'left': {'value': 7,
'left': None,
'right': {'value': 9, 'left': None, 'right': None}},
'right': {'value': 13,
'left': {'value': 11, 'left': None, 'right': None},
'right': None}}

mix-in最妙的地方在于,子类既可以沿用它所提供的功能,又可以对其中一些地方做自己的处理。

例如,我们从普通的二叉树(BinaryTree)派生了一个子类,让这种特殊的BinaryTreeWithParent二叉树能够把指向上级节点的引用保留下来。但问题是,这种二叉树的to_dict方法是从ToDictMixin继承来的,它所触发的_traverse方法,在面对循环引用时,会无休止地递归下去。

为了避免无限循环,可以覆盖BinaryTreeWithParent._traverse方法,让它对指向上级节点的引用做专门的处理,对其他值则沿用继承的方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class BinaryTreeWithParent(BinaryTree):
def __init__(self, value, left=None,
right=None, parent=None):
super().__init__(value, left=left, right=right)
self.parent = parent

def _traverse(self, key, value):
if (isinstance(value, BinaryTreeWithParent) and
key == 'parent'):
return value.value
else:
return super()._traverse(key, value)

root = BinaryTreeWithParent(10)
root.left = BinaryTreeWithParent(7, parent=root)
root.left.right = BinaryTreeWithParent(9, parent=root.left)
print(root.to_dict())

>>>
{'value': 10,
'left': {'value': 7,
'left': None,
'right': {'value': 9,
'left': None,
'right': None,
'parent': 7},
'parent': 10},
'right': None,
'parent': None}

多个mix-in可以组合起来使用。

第42条 优先考虑用public属性表示应受保护的数据,不要用private属性表示

Python的类属性只有两种访问级别,public与private:

1
2
3
4
5
6
7
class MyObject:
def __init__(self):
self.public_field = 5
self.__private_field = 10

def get_private_field(self):
return self.__private_field

private字段只给这个类自己使用,不给子类以及其他。

实现方式为变换属性名称,如上例中__private_field则变为_MyObject__private_field。

  • Python编译器无法绝对禁止外界访问 private属性。
  • 从一开始就应该考虑允许其他类能继承这个类,并利用其中的内部API与属性去实现更多功能,而不是把它们藏起来。
  • 把需要保护的数据设计成protected字段,并用文档加以解释,而不要通过private属性限制访问。
  • 只有在子类不受控制且名称有可能与超类冲突时,才可以考虑给超类设计private属性。

第43条 自定义的容器类型应该从collections.abc继承

如果要编写的新类比较简单,那么可以直接从Python的容器类型里面继承。

1
2
3
4
5
6
7
8
9
class FrequencyList(list):
def __init__(self, members):
super().__init__(members)

def frequency(self):
count = {}
for item in self:
count[item] = count.get(item, 0) + 1
return count

为了方便定制容器,Python内置的collections.abc模块定义了一系列抽象基类,把每种容器类型应该提供的所有常用方法都写了出来,只需要继承这些抽象基类。同时,如果忘了实现某些必备的方法,程序会报错。

1
2
3
4
5
6
7
8
from collections.abc import Sequence

class TreeNode(Sequence):
def __getitem__(self, index: int):
pass

def __len__(self) -> int:
pass

第六章 元类与属性

元类能够拦截Python的class语句,让系统每次定义类的时候,都能实现某些特殊的行为。

除了元类,Python还可以动态地定制属性访问操作。

第44条 用纯属性与修饰器取代旧式的setter与getter方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class VoltageResistance:
def __init__(self, ohms):
self.ohms = ohms
self._voltage = 0
self.current = 0

@property
def voltage(self):
return self._voltage

@voltage.setter
def voltage(self, voltage):
self._voltage = voltage
self.current = self._voltage / self.ohms

@property
def ohms(self):
return self._ohms

@ohms.setter
def ohms(self, ohms):
if hasattr(self, '_ohms'):
raise AttributeError('Ohms is immutable!')
self._ohms = ohms

@property最大的缺点是,通过它而编写的属性取及属性设置方法只能由子类共享。与此无关的类不能共用这份逻辑。但是没关系,Pvthon还支持描述符(descriptor,参见第46条),我们可以利用这种机制把早前编写的属性获取与属性设置逻辑复用到其他许多地方。

  • 给新类定义接口时,应该先从简单的public属性写起,避免定义setter与 getter方法。
  • 如果在访问属性时确实有必要做特殊的处理,那就通过@property来定义获取属性与设置属性的方法。
  • 实现@property方法时,应该遵循最小惊讶原则,不要引发奇怪的副作用。
  • @property方法必须执行得很快。复杂或缓慢的任务,尤其是涉及IO或者会引发副作用的那些任务,还是用普通的方法来实现比较好。

第45条 考虑用@property实现新的属性访问逻辑,不要急着重构原有的代码

  • 可以利用 @property给已有的实例属性增加新的功能。
  • 可以利用@property 逐渐改善数据模型而不影响已经写好的代码。
  • 如果发现@property使用太过频繁,那可能就该考虑重构这个类了,同时按照旧办法使用这个类的那些代码可能也要重构。

第46条 用描述符来改写需要复用的@property

Python内置的@property最大的缺点就是不方便复用。我们不能把它修饰的方法所使用的逻辑,套用在同一个类的其他属性上面,也不能在无关的类里面复用。

这样的功能最好通过描述符实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from weakref import WeakKeyDictionary

class Grade:
def __init__(self):
# 用WeakKeyDictionary是让系统自动处理内存泄漏问题
# 用dict是要把每个Exam实例在这个属性上的取值都记录下来
self._value = WeakKeyDictionary()

def __get__(self, instance, instance_type):
if instance is None:
return self
return self._value.get(instance, 0)

def __set__(self, instance, value):
if not (0 <= value <= 100):
raise ValueError('Grade must be between 0 and 100')
self._value[instance] = value

class Exam:
math_grade = Grade()
writing_grade = Grade()
science_grade = Grade()

class Exam:
math_grade = Grade()
writing_grade = Grade()
science_grade = Grade()

第47条 针对惰性属性使用__getattr__、__getattribute__及__setattr__

如果类中定义了__getattr__,那么每当访问该类对象属性,而且实例字典里又找不到这个属性时,系统就会触发__getattr__方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class LazyRecord:
def __init__(self):
self.exists = 5

def __getattr__(self, item):
value = f'Value for {item}'
setattr(self, item, value)
return value

data = LazyRecord()
print('Before:', data.__dict__)
print('foo: ', data.foo)
print('After :', data.__dict__)

>>>
Before: {'exists': 5}
foo: Value for foo
After : {'exists': 5, 'foo': 'Value for foo'}

__getattribute__方法只要访问对象中的属性就会触发,无论是否在__dict__字典里。

__setattr__方法只要给实例中的属性赋值就会触发。

注意:在实现__getattribute__方法与__setattr__方法时,如果要使用本对象的普通属性,那么应该通过super()也就是object来使用,以避免无限递归。

第48条 用__init_subclass__验证子类写得是否正确

元类最简单的一种用法是验证某个类定义得是否正确。如果要构建一套比较复杂的类体系,那我们可能得确保这套体系中的类采用的都是同一种风格,为此我们可能需要判断这些类有没有重写必要的方法,或者判断类属性之间的关系是否合理。元类提供了一种可靠的手段,只要根据这个元类来定义新类,就能用元类中的验证逻辑核查新类的代码写得是否正确。

一般来说,我们会在类的__init__方法里面检查新对象构造得是否正确(参见第44条)。但有的时候,整个类的写法可能都是错的,而不单单是该类的某个对象构造得有问题,所以我们想尽早拦住这种错误。例如,当程序刚刚启动并把包含这个类的模块加载进来时,我们就想验证这个类写得对不对,此时便可利用元类来实现。

在讲解如何用自定义的元类验证子类之前,我们首先必须明白元类的标准用法。元类应该从type之中继承。在默认情况下,系统会把通过这个元类所定义的其他类发送给元类的__new__方法,让该方法知道那个类的class语句是怎么写的。下面就定义这样一个元类,如果用户通过这个元类来定义其他类,那么在那个类真正构造出来之前,我们可以先在__new__里面观察到它的写法并做出修改。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Meta(type):
def __new__(mcs, name, bases, class_dict):
print(f'* Running {mcs}.__new__ for {name}')
print('Bases:', bases)
print(class_dict)
return type.__new__(mcs, name, bases, class_dict)

class MyClass(metaclass=Meta):
stuff = 123

def foo(self):
pass

class MySubClass(MyClass):
other = 567

def bar(self):
pass

>>>
* Running <class '__main__.Meta'>.__new__ for MyClass
Bases: ()
{'__module__': '__main__', '__qualname__': 'MyClass', 'stuff': 123, 'foo': <function MyClass.foo at 0x000001870F665DA0>}
* Running <class '__main__.Meta'>.__new__ for MySubClass
Bases: (<class '__main__.MyClass'>,)
{'__module__': '__main__', '__qualname__': 'MySubClass', 'other': 567, 'bar': <function MySubClass.bar at 0x000001870F665D00>}

我们可以在元类的__new__方法里而添加―些代码,用来判断根据这个元类所定义的类的各项参数是否合理。例如,要用不同的类来表示边数不同的多边形(polygon)。如果把这些类都纳入同一套体系,那么可以定义这样一个元类,让该体系内的所有类都受它约束。我们在这个元类的__new__里面检查那些类的边数(sides是否有效。注意,不要把检查逻辑运用到类体系的顶端,也就是基类Polygon上面。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class ValidatePolygon(type):
def __new__ (mcs, name, bases, class_dict):
#Only validate subc1asses of the Polygon class
if bases:
if class_dict['sides'] < 3:
raise ValueError('Polygons need 3+ sides ')
return type.__new__(mcs, name, bases, class_dict)

class Polygon(metaclass=ValidatePolygon):
sides = None # Must be specified by subclasses

@classmethod
def interior_angles(cls):
return (cls.sides - 2) * 180

class Triangle(Polygon):
sides = 3

class Rectangle(Polygon):
sides = 4

print('Before class')

class Line(Polygon):
print('Before sides')
sides = 2
print('After sides')

print('After class')

>>>
Before class
Before sides
After sides
Traceback (most recent call last):
...
ValueError: Polygons need 3+ sides

这样一项基本的任务竟然要写这么多代码才能实现。好在 Python 3.6引人了一种简化的写法,能够直接通过__init_subclass__这个特殊的类方法实现相同的功能,这样就不用专门定义元类了。下面我们改用这个机制来实现与刚才相同的验证逻辑。

1
2
3
4
5
6
7
8
9
10
11
class BetterPolygon:
sides = None # Must be specified by subclasses

def __init_subclass__(cls):
super().__init_subclass__()
if cls.sides < 3:
raise ValueError('Polygons need 3+ sides')

@classmethod
def interior_angles(cls):
return (cls.sides - 2) * 180

现在的代码简短多了,完全不需要定义ValidatePolygon这样一个元类。

  • 如果某个类是根据元类所定义的,那么当系统抑该类的class语句体全部处理完之后,就会将这个类的写法告诉元类的__new__方法。
  • 可以利用元类在类创建完成前检视或修改开发者根据这个元类所定义的其他类,但这种机制通常显得有点笨重。
  • __init_subclass__能够用来检查子类定义得是否合理,如果不合理,那么可以提前报错,让程序无法创建出这种子类的对象。
  • 在分层的或者涉及多重继承的类体系里面,一定别忘了在你写的这些类的__init_subclass__内通过super()来调用超类的__init_subclass__方法,以便按照正确的顺序触发各类的验证逻辑。

第49条 用__init_subclass__记录现有的子类

  • 类注册(Class registration)是个相当有用的模式,可以用来构建模块式的Python程序。
  • 我们可以通过基类的元类把用户从这个基类派生出来的子类自动注册给系统。
  • 利用元类实现类注册可以防止由于用户忘记注册而导致程序出现问题。
  • 优先考虑通过__init_subclass__实现自动注册,而不要用标准的元类机制来实现,因为__init_subclass__更清晰,更便于初学者理解。

第50条 用__set_name__给类属性加注解

有了元类、DatabaseRow基类以及修改过的Field描述符,我们在给客户类定义字段时,就不用手工传入字段名了,代码也不像之前那样冗余了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class Meta(type):
def __new__(mcs, name, bases, class_dict):
for key, value in class_dict.items():
if isinstance(value, Field):
value.name = key
value.internal_name = '_' + key
cls = type.__new__(mcs, name, bases, class_dict)
return cls

class Field:
def __init__(self):
# These wi11be assigned by the metaclass
self.name = None
self.internal_name = None

def __get__(self, instance, instance_type):
if instance is None:
return self
return getattr(instance, self.internal_name, '')

def __set__(self, instance, value):
setattr(instance, self.internal_name, value)

class DatabaseRow(metaclass=Meta):
pass

# 有些重复
class Customer:
first_name = Field('first_name')
last_name = Field('last_name')
prefix = Field('prefix')
suffix = Field('suffix')

# 改成下面这种
class BetterCustomer(DatabaseRow):
first_name = Field()
last_name = Field()
prefix = Field()
suffix = Field()

但这个方法的缺点是,要想在类中声明Field字段,这个类必须从DatabaseRow继承。假如忘了或者结构上不方便这样继承,那么代码就无法正常运行。

这个问题可以通过给描述符定义__set_name__特殊方法来解决。这是Python 3.6引入的新功能:如果某个类用这种描述符的实例来定义字段,那么系统就会在描述符上面触发这个特殊方法。系统会把采用这个描述符实例作字段的那个类以及字段的名称,当成参数传给__set_name__。下面我们将 Meta.__new__之中的逻辑移动到Field描述符元类__set_name__里面,这样一来,就不用定义元类了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Field:
def __init__(self):
# These wi11be assigned by the metaclass
self.name = None
self.internal_name = None

def __set_name__(self, owner, name):
self.name = name
self.internal_name = '_' + name

def __get__(self, instance, instance_type):
if instance is None:
return self
return getattr(instance, self.internal_name, '')

def __set__(self, instance, value):
setattr(instance, self.internal_name, value)
  • 我们可以通过元类把利用这个元类所定义的其他类拦截下来,从而在程序开始使用那些类之前,先对其中定义的属性做出修改。
  • 描述符与元类搭配起来,可以形成一套强大的机制,让我们既能采用声明式的写法来定义行为,又能在程序运行时检视这个行为的具体执行情况。
  • 你可以给描述符定义__set_name__方法,让系统把使用这个描述符做属性的那个类以及它在类里的属性名通过方法的参数告诉你。
  • 用描述符直接操纵每个实例的属性字典,要比把所有实例的属性都放到一份字典里更好,因为后者要求我们必须使用weakref内置模块之中的特殊字典来记录每个实例的属性值以防止内存泄漏。

第51条 优先考虑通过类修饰器来提供可组合的扩充功能,不要使用元类

尽管元类允许我们用各种方式定制其他类的创建逻辑,但有些情况它未必能处理得很好。例如,要写一个辅助函数来修饰类中的每个方法,把这些方法执行时所用的参数,返回值以及抛出的异常都处理好,以便调试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from functools import wraps

def trace_func(func):
if hasattr(func, 'tracing'):
return func

@wraps(func)
def wrapper(*args, **kwargs):
result = None
try:
result = func(*args, **kwargs)
return func
except Exception as e:
result = e
raise
finally:
print(f'{func.__name__}({args!r}, {kwargs!r}) -> '
f'{result!r}')

wrapper.tracing = True
return wrapper

class TraceDict(dict):
@trace_func
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@trace_func
def __setitem__(self, *args, **kwargs):
super().__setitem__(*args, **kwargs)

@trace_func
def __getitem__(self, *args, **kwargs):
super().__getitem__(*args, **kwargs)

...

但这样必须在子类中把需要受@trace_func修饰的方法全都重写一遍,即便子类只想沿用超类的实现方式。

解决这个问题,其中一个办法是通过元类自动修饰那个类的所有方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import types

trace_types = (
types.MethodType,
types.FunctionType,
types.BuiltinFunctionType,
types.BuiltinMethodType,
types.MethodDescriptorType,
types.ClassMethodDescriptorType
)

class TraceMeta(type):
def __new__(mcs, name, bases, class_dict):
class_ = super().__new__(mcs, name, bases, class_dict)

for key in dir(class_):
value = getattr(class_, key)
if isinstance(value, trace_types):
wrapped = trace_func(value)
setattr(class_, key, wrapped)

return class_

class TraceDict(dict, metaclass=TraceMeta):
pass

但元类之间很难组合,而类修饰器比较灵活,它们可以施加在同一个类上,并且不会发生冲突。

1
2
3
4
5
6
7
8
9
10
11
def trace(class_):
for key in dir(class_):
value = getattr(class_, key)
if isinstance(value, trace_types):
wrapped = trace_func(value)
setattr(class_, key, wrapped)
return class_

@trace
class TraceDict(dict):
pass

第七章 并发与并行

第52条 用subprocess管理子进程

Python里面有许多方法都可以运行子进程,其中最好的办法是通过内置的subprocess模块来管理。

1
2
3
4
5
6
7
8
9
10
import subprocess

result = subprocess.run(
['echo', 'Hello from the child'],
capture_output=True,
encoding='utf-8'
)

result.check_returncode()
print(result.stdout)

子进程可以独立于父进程运行:

1
2
3
4
5
6
import subprocess

proc = subprocess.Popen(['sleep', '1'])
while proc.poll() is None:
print('Working...')
print('Exit status', proc.poll())

把子进程从父进程剥离,可以让程序平行地运行多条子进程。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import subprocess
import time

start = time.time()
sleep_procs = []
for _ in range(10):
proc = subprocess.Popen(['sleep', '1'])
sleep_procs.append(proc)

for proc in sleep_procs:
proc.communicate()

end = time.time()
delta = end - start

>>>
1.05s

还可以把数据通过管道发送给子进程运行的外部命令,然后将那条命令的输出结果获取到程序中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import subprocess
import os

def run_encrypt(data):
env = os.environ.copy()
env['password'] = 'zkf5cve|ce\e*-$^V@d23'
proc = subprocess.Popen(
['openssl', 'enc', '-des3', '-pass', 'env:password'],
env=env,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE
)
proc.stdin.write(data)
proc.stdin.flush()
return proc

这些平行运行的子进程还可以分别与另一套平行的子进程对接,形成许多条平行的管道。

在调用communicate方法时可以指定timeout参数,用来停掉不正常的子进程。

1
2
3
4
5
6
7
8
import subprocess

proc = subprocess.Popen(['sleep', '10'])
try:
proc.communicate(timeout=0.1)
except subprocess.TimeoutExpired:
proc.terminate()
proc.wait()

第53条 可以用线程执行阻塞式I/O,但不要用它做并行计算

CPython的全局解释器锁(Global Interpreter Lock,GIL)限制,无法实现多核心并行计算。(笔者注:在2023年7月。Python 团队已经正式接受了删除 GIL 的这个提议,并将其设置为可选模式,但等到彻底去除GIL估计还有5年的时间)

  • 即便计算机具备多核的 CPU,Python线程也无法真正实现并行,因为它们会受全局解释器锁(GIL)牵制。
  • 虽然Python的多线程机制受GIL影响,但还是非常有用的,因为我们很容易就能通过多线程模拟同时执行多项任务的效果。
  • 多条Python线程可以并行地执行多个系统调用,这样就能让程序在执行阻塞式的I/O任务时,继续做其他运算。

第54条 利用Lock防止多个线程争用同一份数据

程序在执行完当前这条字节码指令之后,可能会被Python系统切换走,等它稍后切换回来继续执行下一条字节码指令时,当前的数据或许已经与实际值脱节了,因为中途切换进来的其他线程可能更新过这个值。所以,多个线程同时访问同一个对象是很危险的。每条线程在操作这份数据时,都有可能遭到其他线程打扰,因此数据之中的固定关系或许已经被别的线程破坏了,这会令程序陷入混乱状态。为了避免数据争用,可以使用threading中的Lock类,它相当于互斥锁(mutex)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from threading import Thread, Lock

class Counter:
def __init__(self):
self.lock = Lock()
self.count = 0

def increment(self, offset):
with self.lock:
self.count += offset

def worker(how_many, counter):
for _ in range(how_many):
counter.increment(1)

how_many = 10**5
counter = Counter()

threads = []
for i in range(5):
thread = Thread(target=worker,
args=(how_many, counter))
threads.append(thread)
thread.start()

for thread in threads:
thread.join()

print(counter.count)

>>>
500000

第55条 用Queue来协调各线程之间的工作进度

  • 管道非常适合用来安排多阶段的任务,让我们能够把每一阶段都交给各自的线程去执行,这尤其适合用在I/O密集型的程序里面。
  • 构造这种并发的管道时,有很多问题需要注意,例如怎样防止线程频繁地查询队列状态,怎样通知线程尽快结束操作,以及怎样防止管道出现拥堵等。
  • 我们可以利用Queue类所具有的功能来构造健壮的管道系统,因为这个类提供了阻塞式的入队(put)与出队(get)操作,而且可以限定缓冲区的大小,还能够通过task-done 与join来确保所有元素都已处理完毕。

第56条 学会判断什么场合必须做并发

案例:康威生命游戏,经典的有限状态自动机

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
ALIVE = '*'
EMPTY = '-'

class Grid:
def __init__(self, height, width):
self.height = height
self.width = width
self.rows = []
for _ in range(self.height):
self.rows.append([EMPTY] * self.width)

def get(self, y, x):
return self.rows[y % self.height][x % self.width]

def set(self, y, x, state):
self.rows[y % self.height][x % self.width] = state

def __str__(self):
output = ''
for row in self.rows:
for cell in row:
output += cell
output += '\n'
return output

def count_neighbors(y, x, get):
n_ = get(y - 1, x + 0) # North
ne = get(y - 1, x + 1) # Northeast
e_ = get(y + 0, x + 1) # East
se = get(y + 1, x + 1) # Southeast
s_ = get(y + 1, x + 0) # South
sw = get(y + 1, x - 1) # Southwest
w_ = get(y + 0, x - 1) # West
nw = get(y - 1, x - 1) # Northwest
neighbor_states = [n_, ne, e_, se, s_, sw, w_, nw]
count = 0
for state in neighbor_states:
if state == ALIVE:
count += 1
return count

def game_logic(state, neighbors):
if state == ALIVE:
if neighbors < 2:
return EMPTY # Die: Too few
elif neighbors > 3:
return EMPTY # Die: Too many
else:
if neighbors == 3:
return ALIVE # Regenerate
return state

def step_cell(y, x, get, set):
state = get(y, x)
neighbors = count_neighbors(y, x, get)
next_state = game_logic(state, neighbors)
set(y, x, next_state)

def simulate(grid):
next_grid = Grid(grid.height, grid.width)
for y in range(grid.height):
for x in range(grid.width):
step_cell(y, x, grid.get, next_grid.set)
return next_grid

# 多人在线的话,加入IO操作
def game_logic(state, neighbors):
# Do some blocking input/output in here:
data = my_socket.recv(100)

game_logic这种写法的问题在于,它会拖慢整个程序的速度。如果game_logic函数每次执行的I/O操作需要100毫秒才能完成(与国外的玩家通信一个来回,确实有可能需要这么长时间),那么把整张网格向前推进一代最少需要45秒,因为 simulate函数在推进网格时,是一个一个单元格来计算的。它需要把这45个单元格按顺序计算一遍。这对于网络游戏来说,实在太慢,让人没耐心玩下去。另外,这个方案也无法扩展,假如单元格的数量增加到一万,那么计算新一代网格所花的总时间就会超过15分钟。
若想降低延迟时间,应该平行地执行这些I/O操作,这样的话,无论网格有多大,都只需要100毫秒左右就能推进到下一代。针对每个工作单元开辟一条执行路径,这种模式叫作扇出(fan-out),对于本例来说,工作单元指的是网格中的单元格。然后,要等待这些并发的工作单元全部完工,才能执行下一个环节,这种模式叫作扇入(fan-in),对于本例来说,下一个环节指的是让整张网格进入新的一代。
Python提供了许多内置的工具,可以实现fan-out与fan-in模式,这些工具各有利弊。我们要了解每种方案的优点和缺点,这样才能用最合适的工具来应对具体的需求。下面几条会继续以生命游戏为例,详细讲解这些工具(参见第57条、第58条、第59条与第60条)。

  • 程序范围变大、需求变复杂之后,经常要用多条路径平行地处理任务。
  • fan-out与fan-in是最常见的两种并发协调(concurrency coordination)模式,前者用来生成一批新的并发单元,后者用来等待现有的并发单元全部完工。
  • Python提供了很多种实现fan-out与fan-in的方案。

第57条 不要在每次fan-out时都新建一批Thread实例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class LockingGrid(Grid):
def __init__(self, height, width):
super().__init__(height, width)
self.lock = Lock()

def __str__(self):
with self.lock:
return super().__str__()

def get(self, y, x):
with self.lock:
return super().get(y, x)

def set(self, y, x, state):
with self.lock:
return super().set(y, x, state)

def simulate_threaded(grid):
next_grid = LockingGrid(grid.height, grid.width)

threads = []
for y in range(grid.height):
for x in range(grid.width):
args = (y, x, grid.get, next_grid.set)
thread = Thread(target=step_cell, args=args)
thread.start() # Fan out
threads.append(thread)

for thread in threads:
thread.join() # Fan in

return next_grid

如果需要一大批执行路径分头去执行某项任务,而且还要频繁地启动并停止这批执行路径,那么每次都需手工新建一批线程,这肯定不是个好办法。Python提供了其他几种更合适的方案(参见第58条、第59条与第60条)。

  • 每次都手工创建一批线程,是有很多缺点的,例如:创建并运行大量线程时的开销比较大,每条线程的内存占用量比较多,而且还必须采用Lock等机制来协调这些线程。
  • 线程本身并不会把执行过程中遇到的异常抛给启动线程或者等待该线程完工的那个人,所以这种异常很难调试。

第58条 学会正确地重构代码,以便用Queue做并发

  • 把队列(Queue)与一定数量的工作线程搭配起来,可以高效地实现fan-out(分派)与fan-in(归集)。
  • 为了改用队列方案来处理IO,我们重构了许多代码,如果管道要分成好几个环节,那么要修改的地方会更多。
  • 利用队列并行地处理IO任务,其处理IO任务量有限,我们可以考虑用Python内置的某些功能与模块打造更好的方案。

第59条 如果必须用线程做并发,那就考虑通过ThreadPoolExecutor实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from concurrent.futures import ThreadPoolExecutor

def simulate_pool(pool, grid):
next_grid = LockingGrid(grid.height, grid.width)

futures = []
for y in range(grid.height):
for x in range(grid.width):
args = (y, x, grid.get, next_grid.set)
future = pool.submit(step_cell, *args) # Fan out
futures.append(future)

for future in futures:
future.result() # Fan in

return next_grid

columns = ColumnPrinter()
with ThreadPoolExecutor(max_workers=10) as pool:
for i in range(5):
columns.append(str(grid))
grid = simulate_pool(pool, grid)

ThreadPoolExecutor方案仍然有个很大的缺点,就是I/O并行能力不高,即便把max_workers设成100,也无法高效地应对那种有一万多个单元格,且每个单元格都要同时做IO的情况。如果你面对的需求,没办法用异步方案解决,而是必须执行完才能往后走(例如文件I/O),那么ThreadPoolExecutor是个不错的选择。然而在许多情况下,其实还有并行能力更强的办法可以考虑(参见第60条)。

  • 利用 ThreadPoolExecutor,我们只需要稍微调整一下代码,就能够并行地执行简单的I/O操作,这种方案省去了每次 fan-out(分派)任务时启动线程的那些开销。
  • 虽然 ThreadPoolExecutor不像直接启动线程的方案那样,需要消耗大量内存,但它的IO并行能力也是有限的。因为它能够使用的最大线程数需要提前通过max_workers参数指定。

第60条 用协程实现高并发的I/O

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
async def game_logic(state, neighbors):
# Do some input/output in here:
data = await my_socket.read(50)

async def game_logic(state, neighbors):
if state == ALIVE:
if neighbors < 2:
return EMPTY # Die: Too few
elif neighbors > 3:
return EMPTY # Die: Too many
else:
if neighbors == 3:
return ALIVE # Regenerate
return state

async def step_cell(y, x, get, set):
state = get(y, x)
neighbors = count_neighbors(y, x, get)
next_state = await game_logic(state, neighbors)
set(y, x, next_state)

async def simulate(grid):
next_grid = Grid(grid.height, grid.width)

tasks = []
for y in range(grid.height):
for x in range(grid.width):
task = step_cell(
y, x, grid.get, next_grid.set) # Fan out
tasks.append(task)

await asyncio.gather(*tasks) # Fan in

return next_grid

columns = ColumnPrinter()
for i in range(5):
columns.append(str(grid))
grid = asyncio.run(simulate(grid)) # Run the event loop

协程的优点是,能够把那些与外部环境交互的代码(例如I/O调用)与那些实现自身需求的代码(例如事件循环)解耦。这让我们可以把重点放在实现需求所用的逻辑上面,而不用专门花时间去写一些代码来确保这些需求能够并发地执行。

  • 协程是采用async 关键字所定义的函数。如果你想执行这个协程,但并不要求立刻就获得执行结果,而是稍后再来获取,那么可以通过 await关键字表达这个意思。
  • 协程能够制造出这样一种效果,让人以为程序里有成千上万个函数都在同一时刻高效地运行着。
  • 协程可以用fan-out(分派)与fan-in(归集)模式实现并行的I/O操作,而且能够克服用线程做I/O时的缺陷。

第61条 学会用asyncio改写那些通过线程实现的I/O

  • Python提供了异步版本的for循环、with语句、生成器与推导机制,而且还有很多辅助的库函数,让我们能够顺利地迁移到协程方案。
  • 我们很容易就能利用内置的asyncio模块来改写代码,让程序不要再通过线程执行阻塞式的I/O,而是改用协程来执行异步I/O。

第62条 结合线程与协程,将代码顺利迁移到asyncio

  • asyncio模块的事件循环提供了一个返回awaitable对象的run_in_executor方法,它能够使协程把同步函数放在线程池执行器(ThreadPoolExecutor)里面执行,让我们可以顺利地将采用线程方案所实现的项目,从上至下地迁移到asyncio方案。
  • asyncio模块的事件循环还提供了一个可以在同步代码里面调用的run_until_complete方法,用来运行协程并等待其结束。它的功能跟asyncio.run_coroutine_threadsafe类似,只是后者面对的是跨线程的场合,而前者是为同一个线程设计的。这些都有助于将采用线程方案所实现的项目从下至上地迁移到asyncio方案。

第63条 让asyncio的时间循环保持畅通,以便进一步提升程序的响应能力

  • 把系统调用(包括阻塞式的I/O以及启动线程等操作)放在协程里面执行,会降低程序的响应能力,增加延迟感。
  • 调用asyncio.run时,可以把debug参数设为True,这样能够知道哪些协程降低了事件循环的反应速度。

第64条 考虑用concurrent.futures实现真正的并行计算

  • 把需要耗费大量CPU资源的计算任务改用C扩展模块来写,或许能够有效提高程序的运行速度,同时又让程序里的其他代码依然能够利用Python语言自身的特性。但是,这样做的开销比较大,而且容易引入bug。
  • Python自带的multiprocessing模块提供了许多强大的工具,让我们只需要耗费很少的精力,就可以把某些类型的任务平行地放在多个CPU核心上面处理。
  • 要想发挥出multiprocessing模块的优势,最好是通过concurrent.futures模块及其ProcessPoolExecutor类来编写代码,因为这样做比较简单。
  • 只有在其他方案全都无效的情况下,才可以考虑直接使用multiprocessing里面的高级功能(那些功能用起来相当复杂)。

第八章 稳定与性能

Python有很多内置的特性与模块,可以帮我们加固程序代码,让它应对各种各样的情况。

第65条 合理利用try/except/else/finally结构中的每个代码块

try/finally形式

确保无论某段代码有没有异常,与它配套的清理代码都必须得到执行,同时还想在出现异常的时候,把这个异常向上传播。

1
2
3
4
5
6
7
8
9
def try_finally_example(filename):
print('* Opening file')
handle = open(filename, encoding='utf-8') # May raise OSError
try:
print('* Reading data')
return handle.read() # May raise UnicodeDecodeError
finally:
print('* Calling close()')
handle.close() # Always runs after try block

try/except/else形式

在某段代码发生特定类型异常时,把这种异常向上传播,同时又要在代码没有发生异常的情况下,执行另一端代码。

1
2
3
4
5
6
7
8
9
10
11
12
import json

def load_json_key(data, key):
try:
print('* Loading JSON data')
result_dict = json.loads(data) # May raise ValueError
except ValueError as e:
print('* Handling ValueError')
raise KeyError(key) from e
else:
print('* Looking up key')
return result_dict[key] # May raise KeyError

完整的try/except/else/finally形式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
UNDEFINED = object()
DIE_IN_ELSE_BLOCK = False

def divide_json(path):
print('* Opening file')
handle = open(path, 'r+') # May raise OSError
try:
print('* Reading data')
data = handle.read() # May raise UnicodeDecodeError
print('* Loading JSON data')
op = json.loads(data) # May raise ValueError
print('* Performing calculation')
value = (
op['numerator'] /
op['denominator']) # May raise ZeroDivisionError
except ZeroDivisionError as e:
print('* Handling ZeroDivisionError')
return UNDEFINED
else:
print('* Writing calculation')
op['result'] = value
result = json.dumps(op)
handle.seek(0) # May raise OSError
if DIE_IN_ELSE_BLOCK:
import errno
import os
raise OSError(errno.ENOSPC, os.strerror(errno.ENOSPC))
handle.write(result) # May raise OSError
return value
finally:
print('* Calling close()')
handle.close() # Always runs
  • try/finally形式的复合语句可以确保,无论try块是否抛出异常,finally块都会得到运行。
  • 如果某段代码应该在前一段代码顺利执行之后加以运行,那么可以把它放到else块里面,而不要把这两段代码全都写在try块之中。这样可以让try块更加专注,同时也能够跟except块形成明确对照:except块写的是try块没有顺利执行时所要运行的代码。
  • 如果你要在某段代码顺利执行之后多做一些处理,然后再清理资源,那么通常可以考虑把这三段代码分别放在try、else与finally块里。

第66条 考虑用contextlib与with语句来改写可复用的try/finally代码

with语句可以用来强调某段代码需要在特殊情境之中执行。与相应的try/finally结构是一个意思,但写起来比较方便。

如果想让其他对象与函数也可以这样用在with里面,可以用内置的contextlib模块实现。这个模块提供了contextmanager修饰器,比实现__enter__与__exit__特殊方法的标准做法简单。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import logging
from contextlib import contextmanager

def my_function():
logging.debug('Some debug data')
logging.error('Error')
logging.debug('More debug data')

@contextmanager
def debug_logging(level):
logger = logging.getLogger()
old_level = logger.getEffectiveLevel()
logger.setLevel(level)
try:
yield
finally:
logger.setLevel(old_level)

with debug_logging(logging.DEBUG):
my_function()

带目标的with语句

with … as …,可以把情境管理器所返回的对象赋给as右侧的局部变量。例如打开文件。

如果想让自己的函数也支持这个结构:

1
2
3
4
5
6
7
8
9
10
11
12
@contextmanager
def log_level(level, name):
logger = logging.getLogger(name)
old_level = logger.getEffectiveLevel()
logger.setLevel(level)
try:
yield logger
finally:
logger.setLevel(old_level)

with log_level(logging.DEBUG, 'my_log') as logger:
logger.debug('debug message')

第67条 用datetime模块处理本地时间,不要用time模块

协调世界时(UTC)是标准的时间表示方法,不依赖于时区。但UTC不太直观,很多程序中要涉及到时区转换。

time模块

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import time

now = 1552774475
local_tuple = time.localtime(now)
time_format = '%Y-%m-%d %H:%M:%S'
time_str = time.strftime(time_format, local_tuple)
print(time_str)

time_tuple = time.strptime(time_str, time_format)
utc_now = time.mktime(time_tuple)
print(utc_now)

import os

if os.name == 'nt':
print("This example doesn't work on Windows")
else:
parse_format = '%Y-%m-%d %H:%M:%S %Z'
depart_sfo = '2019-03-16 15:45:16 PDT'
time_tuple = time.strptime(depart_sfo, parse_format)
time_str = time.strftime(time_format, time_tuple)
print(time_str)

try:
arrival_nyc = '2019-03-16 23:33:24 EDT'
time_tuple = time.strptime(arrival_nyc, time_format)
except:
logging.exception('Expected')
else:
assert False

该模块的功能依赖具体的平台运作,显得不太可靠。time模块没办法稳定地处理多个时区。

datatime模块

UTC与本地时间转换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from datetime import datetime, timezone

now = datetime(2019, 3, 16, 22, 14, 35)
now_utc = now.replace(tzinfo=timezone.utc)
now_local = now_utc.astimezone()
print(now_local)

>>>
2019-03-17 06:14:35+08:00

time_str = '2019-03-16 15:14:35'
now = datetime.strptime(time_str, time_format)
time_tuple = now.timetuple()
utc_now = time.mktime(time_tuple)
print(utc_now)

>>>
1552720475.0

datatime里有相应的机制,可以把一个时区的本地时间可靠地转化成另一个时区的本地时间(通过tzinfo类与相关的方法)。pytz模块用以补充缺失的时区信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import pytz

arrival_nyc = '2019-03-16 23:33:24'
nyc_dt_naive = datetime.strptime(arrival_nyc, time_format)
eastern = pytz.timezone('US/Eastern')
nyc_dt = eastern.localize(nyc_dt_naive)
utc_dt = pytz.utc.normalize(nyc_dt.astimezone(pytz.utc))
print(utc_dt)

pacific = pytz.timezone('US/Pacific')
sf_dt = pacific.normalize(utc_dt.astimezone(pacific))
print(sf_dt)

nepal = pytz.timezone('Asia/Katmandu')
nepal_dt = nepal.normalize(utc_dt.astimezone(nepal))
print(nepal_dt)

在操纵时间数据的过程中,总是应该使用UTC时间,只有到最后一步才需要转化成本地时间显示出来。

第68条 用copyreg实现可靠的pickle操作

Python内置的pickle模块可以把对象序列化为字节流,也可以把字节流反序列化成对象。经过pickle处理的字节流,只应该在彼此信任的双方之间传输,pickle没有考虑过安全问题。

pickle 模块的主要用途仅仅是让我们能够把对象轻松地序列化成二进制数据。如果想直接使用这个模块来实现比这更为复杂的需求,那么可能就会看到奇怪的结果。

解决这样的问题,也非常简单,即可以用内置的 copyreg 模块解决。这个模块允许我们向系统注册相关的函数,把 Python 对象的序列化与反序列化操作交给那些函数去处理这样的话,pickle 模块就运作得更加稳定了。

给新属性设定默认值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class GameState:
def __init__(self, level=0, lives=4, points=0):
self.level = level
self.lives = lives
self.points = points

def pickle_game_state(game_state):
kwargs = game_state.__dict__
return unpickle_game_state, (kwargs,)

def unpickle_game_state(kwargs):
return GameState(**kwargs)

import copyreg

copyreg.pickle(GameState, pickle_game_state)

state = GameState()
state.points += 1000
serialized = pickle.dumps(state)
state_after = pickle.loads(serialized)

用版本号标注同一个类的不同定义

有时候我们要做的改动是无法向后兼容的。这个问题可以通过添加版本号解决。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class GameState:
def __init__(self, level=0, points=0, magic=5):
self.level = level
self.points = points

try:
pickle.loads(serialized)
except:
logging.exception('Expected')
else:
assert False

def pickle_game_state(game_state):
kwargs = game_state.__dict__
kwargs['version'] = 2
return unpickle_game_state, (kwargs,)

def unpickle_game_state(kwargs):
version = kwargs.pop('version', 1)
if version == 1:
del kwargs['lives']
return GameState(**kwargs)

正确处理类名变化

这个问题还是可以用 copyreg 解决。如果我们先通过copyreg.pickle注册了负责执行pickle操作的函数,那么系统在把对象转换成序列化数据的时候,就不会将当前的类名作为引入路径写到数据里面了,而是会把那个函数所指定的unpickle函数(即unpickle_game_state)作为引入路径写进去。所以,不管这个类将来改成什么名字,这份数据都可以通过unpickle_game_state 还原,只要那个函数在还原的时候,使用的是正确的类名就行。这样相当于多了一个步骤,以前是直接根据类名还原,现在是先把还原任务交给unpickle函数,然后由那个函数去决定应该还原成哪个类的对象。

1
2
3
4
5
6
7
8
9
10
class BetterGameState:
def __init__(self, level=0, points=0, magic=5):
self.level = level
self.points = points
self.magic = magic

copyreg.pickle(BetterGameState, pickle_game_state)

state = BetterGameState()
serialized = pickle.dumps(state)

只有一个地方要注意,就是负责 unpickle 的 unpickle_game_state 函数所在模块的路径不能变,因为这个函数名已经写到了序列化之后的数据里面,将来反序列化的时候,系统要先找到这个函数,然后才能通过它正确地还原对象。

第69条 在需要准确计算的场合,用decimal表示相应的数值

Python整数类型实际上可以表示任意尺寸的整型数据,双精度浮点数类型遵循IEEE 754规范。另外,还提供了标准的复数类型。但是,依然存在问题,例如给国际长途计费:

1
2
3
4
5
6
7
rate = 1.45
seconds = 3*60 + 42
cost = rate * seconds / 60
print(cost)

>>>
5.364999999999999

这个数字比正确答案少了一点点,所以采用浮点数算出的结果可能跟实际结果稍有偏差。另外,如果要四舍五入到分,也会出现5.37与5.36的区别。

这样的计算应该用Python内置的decimal模块的Decimal类来做。这个类默认支持28位小数,且还可以更高。

1
2
3
4
5
6
7
8
9
from decimal import Decimal

rate = Decimal('1.45')
seconds = Decimal(3*60 + 42)
cost = rate * seconds / Decimal(60)
print(cost)

>>>
5.365

Decimal初始值可以用两种办法指定。一是str,二是float或者int,这两种办法在某些小数上会产生不同的效果,对于整数则都相同:

1
2
3
4
5
6
print(Decimal('1.45'))
print(Decimal(1.45))

>>>
1.45
1.4499999999999999555910790149937383830547332763671875

Decimal类提供了quantize函数,可以根据指定的舍入方式把数值调整到某一位。

1
2
3
4
5
6
7
8
9
10
11
from decimal import ROUND_UP

rounded = cost.quantize(Decimal('0.01'), rounding=ROUND_UP)
print(f'Rounded {cost} to {rounded}')

rounded = small_cost.quantize(Decimal('0.01'), rounding=ROUND_UP)
print(f'Rounded {small_cost} to {rounded}')

>>>
Rounded 5.365 to 5.37
Rounded 0.004166666666666666666666666667 to 0.01

对于小数点无限的值,依然会有误差。可以用内置的fractions里面的Fraction类来精确表示。

第70条 先分析性能,然后再优化

Python的动态机制,让我们很难预判程序在运行时的性能。有些操作,看上去似乎比较慢,但实际执行起来却很快(例如操纵字符串,使用生成器等);还有一些操作,看上去似乎比较快,但实际执行起来却很慢(例如访问属性,调用函数等)。让Python程序速度变慢的原因,有时很难观察出来。

所以,最好不要凭感觉去判断,而是应该先获得具体的测评数据,然后再决定怎么优化。Python内置了profiler模块,可以找到程序里面占总执行时间比例最高的一部分,这样的话,我们就可以专心优化这部分代码,而不用执着于对程序性能影响不大的那些地方(因为你把同样的精力投入到那些地方,产生的提速效果不会太好)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# 待测评的插入排序
def insertion_sort(data):
result = []
for value in data:
insert_value(result, value)
return result

def insert_value(array, value):
for i, existing in enumerate(array):
if existing > value:
array.insert(i, value)
return
array.append(value)

# 创建一套随机的测试数据
from random import randint

max_size = 10**4
data = [randint(0, max_size) for _ in range(max_size)]
test = lambda: insertion_sort(data)


# 构建性能分析器
from cProfile import Profile

profiler = Profile()
profiler.runcall(test)


# 统计结果
from pstats import Stats

stats = Stats(profiler)
stats = Stats(profiler, stream=STDOUT)
stats.strip_dirs()
stats.sort_stats('cumulative')
stats.print_stats()

>>>
20003 function calls in 2.690 seconds

Ordered by: cumulative time

ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 2.690 2.690 effective python.py:23(<lambda>)
1 0.004 0.004 2.690 2.690 effective python.py:2(insertion_sort)
10000 2.676 0.000 2.686 0.000 effective python.py:10(insert_value)
9996 0.010 0.000 0.010 0.000 {method 'insert' of 'list' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
4 0.000 0.000 0.000 0.000 {method 'append' of 'list' objects}

有的时候,在评测完整个程序后,可能会发现大部分时间都耗在了某个常见的工具函数上,单凭默认的统计结果,很难看出性能瓶颈在哪。这时候可以用print_callers方法打印统计结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def my_utility(a, b):
c = 1
for i in range(100):
c += a * b

def first_func():
for _ in range(1000):
my_utility(4, 5)

def second_func():
for _ in range(10):
my_utility(1, 3)

def my_program():
for _ in range(20):
first_func()
second_func()

profiler = Profile()
profiler.runcall(my_program)
stats = Stats(profiler, stream=STDOUT)
stats.strip_dirs()
stats.sort_stats('cumulative')
stats.print_stats()

stats = Stats(profiler, stream=STDOUT)
stats.strip_dirs()
stats.sort_stats('cumulative')
stats.print_callers()

>>>
Ordered by: cumulative time

Function was called by...
ncalls tottime cumtime
effective python.py:49(my_program) <-
effective python.py:41(first_func) <- 20 0.005 0.168 effective python.py:49(my_program)
effective python.py:36(my_utility) <- 20000 0.163 0.163 effective python.py:41(first_func)
200 0.001 0.001 effective python.py:45(second_func)
effective python.py:45(second_func) <- 20 0.000 0.002 effective python.py:49(my_program)
{method 'disable' of '_lsprof.Profiler' objects} <-

第71条 优先考虑用deque实现生产者—消费者队列

写程序的时候,经常要用到先进先出的( first-in,first-out,FIFO)队列,这种队列也叫作生产者—消费者队列(producer-consumer queue)或生产—消费队列。FIFO队列可以把某个函数给出的值收集起来,并交给另一个函数按序处理。一般来说,开发者会用Python内置的list类型来实现FIFO队列。

但是,当列表中的元素变多之后,list的性能就会下降。用pop(0)从队列开头移除元素所花的时间竟然跟队列长度呈反比关系。为什么会这样呢?这是因为,用pop(0)删掉开头元素时,需要把后面的每个元素都向前移动一个位置,这相当于把整份列表的内容都改掉。消费函数需要针对list中的所有元素(共有len (queue)个)都做一次pop(0),而每次pop(0)又需要执行大约Ien(queue)次移动,所以总的操作次数就是len(queue)*len(queue)次。这种算法不能够大规模运用。

Python内置的collections模块里面有个 deque类,可以解决这个问题。这个类所实现的是双向队列( double—ended queue),从头部执行插入或尾部执行删除操作,都只需要固定的时间,所以它很适合充当 FIFO 队列。Python内置的集合模块里面有个deque类,可以解天这个问题。这个类所实现的是双向队列(双端队列),从头部执行插人或尾部执行删除操作,都只需要固定的时间,所以它很适合充当FIFO队列。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import collections

def consume_one_email(queue):
if not queue:
return
email = queue.popleft() # Consumer
# Process the email message
print(f'Consumed email: {email.message}')

def my_end_func():
pass

my_end_func = make_test_end()
EMAIL_IT = get_emails()
loop(collections.deque(), my_end_func)

第72条 考虑用bisect搜索已排序队列

Python内置的bisect模块可以更好地搜索有序列表。其中的bisect_left函数,能够迅速地对任何一个有序的序列执行二分搜索。如果序列中有这个值,那么它返回的就是跟这个值相等的头一个元素所在的位置;如果没有,那么它返回的是插入位置,也就是说,把待查的值插到这个位置可以让序列继续保持有序。

bisect模块的二分搜索算法,在复杂度上面是对数级别的。这意味着,线性搜索算法(list.index方法)在包含20个元素的列表中查询目标值所花的时间,已经够这个算法搜索长度为一百万个元素的列表了(math. log2(10**6)大约是19.93)。它要比线性搜索快得多!

1
2
3
4
5
6
7
8
9
from bisect import bisect_left

data = list(range(10**5))

index = bisect_left(data, 91234)
assert index == 91234

index = bisect_left(data, 91234.56)
assert index == 91235

bisect最好的地方,是它不局限于list类型,而是可以用在任何一种行为类似序列对象上面(怎样让其他对象也表现出跟序列相似的行为,请参见第43条)。bisect模块还提供了其他一些功能,可以实现更为高级的用法。

第73条 学会使用heapq制作优先级队列

Python中其它队列都是先进先出队列,会按照接收元素的顺序来保存这些元素。但有的时候,我们想要根据元素的重要程度来排序,这种情况应该用优先级队列。

直接用list实现,对于数据量少的情况或许还行,但肯定无法应对大规模数据。好在Python内置的heapq模块可以解决这个问题,因为它能够高效地实现出优先级队列。模块名称里面的heap指的是堆,这是一种数据结构,可以维护列表中的元素,并且只需要对数级别的时间就可以添加新元素或移除其中最小的元素。

下面用heapq模块实现图书馆管理程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from heapq import heappush

def add_book(queue, book):
heappush(queue, book)

try:
queue = []
add_book(queue, Book('Little Women', '2019-06-05'))
add_book(queue, Book('The Time Machine', '2019-05-30'))
except:
logging.exception('Expected')
else:
assert False

import functools

@functools.total_ordering
class Book:
def __init__(self, title, due_date):
self.title = title
self.due_date = due_date

def __lt__(self, other):
return self.due_date < other.due_date

queue = []
add_book(queue, Book('Pride and Prejudice', '2019-06-01'))
add_book(queue, Book('The Time Machine', '2019-05-30'))
add_book(queue, Book('Crime and Punishment', '2019-06-06'))
add_book(queue, Book('Wuthering Heights', '2019-06-12'))
print([b.title for b in queue])

queue = [
Book('Pride and Prejudice', '2019-06-01'),
Book('The Time Machine', '2019-05-30'),
Book('Crime and Punishment', '2019-06-06'),
Book('Wuthering Heights', '2019-06-12'),
]
queue.sort()
print([b.title for b in queue])

# 也可以不调用sort,而是改用heapq.heapify来构造这个堆。
from heapq import heapify

queue = [
Book('Pride and Prejudice', '2019-06-01'),
Book('The Time Machine', '2019-05-30'),
Book('Crime and Punishment', '2019-06-06'),
Book('Wuthering Heights', '2019-06-12'),
]
heapify(queue)
print([b.title for b in queue])

>>>
['The Time Machine', 'Pride and Prejudice', 'Crime and Punishment', 'Wuthering Heights']
  • 优先级队列让我们能够按照重要程度来处理元素,而不是必须按照先进先出的顺序处理。
  • 如果直接用相关的列表操作来模拟优先级队列,那么程序的性能会随着队列长度的增大而大幅下降,因为这样做的复杂程度是平方级别,而不是线性级别。
  • 通过Python内置的heapq模块所提供的函数,我们完全可以实现基于堆的优先级队列,从而高效地处理大量数据。
  • 要使用heapq模块,我们必须让元素所在的类型支持自然排序,这可以通过对类套用@functools.total_ordering 修饰器并定义__lt__方法来实现。

第74条 考虑用memoryview与bytearray来实现无需拷贝的bytes操作

对bytes实例做切片需要拷贝底层数据,这会浪费 CPU的时间。这段代码可以通过Python内置的memoryview类型来改进,这个类型让程序能够利用CPython的缓冲协议(buffer protocol)高效地操纵字节数据。这套协议属于底层的C API,允许Python运行时系统与C扩展访问底层的数据缓冲,而bytes等实例正是由这种数据缓冲对象所支持的。memoryview最大的优点,是能够在不复制底层数据的前提下,制作出另一个memoryview。下面这段代码,就先把bytes实例封装在memoryview里面,然后再切割,这样不用拷贝底层数据。

1
2
3
4
5
6
7
8
9
10
11
12
13
data = b'shave and a haircut, two bits'
view = memoryview(data)
chunk = view[12:19]
print(chunk)
print('Size: ', chunk.nbytes)
print('Data in view: ', chunk.tobytes())
print('Underlying data:', chunk.obj)

>>>
<memory at 0x0000018512167B80>
Size: 7
Data in view: b'haircut'
Underlying data: b'shave and a haircut, two bits'

由于它执行的是零拷贝操作,因此对需要高速处理大量内存数据的代码来说用处很大。

而bytearray则相当于可修改的bytes,它允许我们修改任意位置上面的内容。bytearray 采用整数表示其中的内容,而不像bytes那样,采用b开头的字面值。

bytearray与bytes一样,也可以用memoryview封装,在这种memoryview上面切割出来的对象,其内容可以用另一份数据替换,这样做,替换的实际上是 memoryview背后那个底层缓冲区里面的相应部分。这使得我们可以通过memoryview来修改它所封装的bytearray,而不像刚才那样,必须先将bytes拆散,然后再拼起来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
my_array = bytearray(b'hello')
my_array[0] = 0x79
print(my_array)

>>>
bytearray(b'yello')

my_array = bytearray(b'row, row, row your boat')
my_view = memoryview(my_array)
write_view = my_view[3:13]
write_view[:] = b'-10 bytes-'
print(my_array)

>>>
bytearray(b'row-10 bytes- your boat')

Python里面很多库之中的方法,例如 socket.recv_into 与 RawIOBase. readinto,都使用缓冲协议来迅速接收或读取数据。这种方法的好处是不用分配内存,也不用给原数据制作复本,它们会把收到的内容直接写入现有的缓冲区。

第九章 测试与调试

第75条 通过repr字符串输出调试信息

  • 把内置类型的值传给print,会打印出便于认读的那种字符串,但是其中不会包含类型信息。
  • 把内置类型的值传给repr,会得到一个能够表示该值的可打印字符串,将这个reepr字符串传给内置的eval函数能够得到原值。
  • 在格式化字符串里用%s处理相关的值,就跟把这个值传给str函数一样,都能得到一个便于认读的那种字符串。如果用%r来处理,那么得到的就是repr字符串。在f-string中,也可以用值来取代其中有待替换的那一部分,并产生便于认读的那种字符串,但如果待替换的部分加了!r后缀,那么替换出来的就是repr字符串。
  • 给类定义__repr__特殊方法,可以让print函数把该类实例的可打印表现展现出来,在实现这个方法时,还可以提供更为详尽的调试信息。
1
2
3
4
5
6
7
8
9
10
11
12
13
int_value = 5
str_value = '5'
print(f'{int_value} == {str_value} ?')

>>>
5 == 5 ?

print(repr(5))
print(repr('5'))

>>>
5
'5'

第76条 在TestCase子类里验证相关行为

在Python里编写测试最经典办法是使用内置的unittest模块。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from unittest import TestCase, main
from utils import to_str

class UtilsTestCase(TestCase):
def test_to_str_bytes(self):
self.assertEqual('hello', to_str(b'hello'))

def test_to_str_str(self):
self.assertEqual('hello', to_str('hello'))

def test_failing(self):
self.assertEqual('incorrect', to_str('hello'))

if __name__ == '__main__':
main()

>>>
Ran 3 tests in 0.020s

FAILED (failures=1)


hello != incorrect

预期:incorrect
实际:hello
<点击以查看差异>

Traceback (most recent call last):
File "XXX.py", line 19, in test_failing
self.assertEqual('incorrect', to_str('hello'))
AssertionError: 'incorrect' != 'hello'
- incorrect
+ hello

测试用例需要安排到TestCase的子类中。在这样的子类中,每个以 test 开头的方法都表示一项测试用例。如果test方法在运行过程中没有抛出任何异常(assert语句所触发的AssertionError也算异常),那么这项测试用例就是成功的,否则就是失败。其中一项测试用例失败,并不影响系统继续执行TestCase子类里的其他 test方法,所以我们最后能够看到总的结果,知道其中有多少项测试用例成功,多少项失败,而不是只要遇到测试用例失败,就立刻停止整套测试。

在修改了软件产品中的某个方法之后,我们可能想把针对该方法而写的测试用例迅速执行一遍,看自己改得对不对。在这种情况下,可以把TestCase子类的名称与test方法的名字直接写在原有的命令右边。

1
python3 utils_test.py UtilsTestCase.test_to_str_bytes

TestCase类里提供了一些辅助方法,可以在测试用例里做断言:assertEqual可以确认二者是否相等;assertTrue确认Boolean表达式是否为True;assertRaises验证结构的主体部分是否会抛出应有的异常。

如果测试用例需要使用比较复杂的逻辑,那么可以把这些逻辑定义成辅助方法放到TestCase子类里。但是必须注意,这种方法的名称不能以test开头,否则系统就会把它们当成测试用例来执行

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from unittest import TestCase, main

def sum_squares(values):
cumulative = 0
for value in values:
cumulative += value ** 2
yield cumulative

class HelperTestCase(TestCase):
def verify_complex_case(self, values, expected):
expect_it = iter(expected)
found_it = iter(sum_squares(values))
test_it = zip(expect_it, found_it)

for i, (expect, found) in enumerate(test_it):
self.assertEqual(
expect,
found,
f'Index {i} is wrong')

# Verify both generators are exhausted
try:
next(expect_it)
except StopIteration:
pass
else:
self.fail('Expected longer than found')

try:
next(found_it)
except StopIteration:
pass
else:
self.fail('Found longer than expected')

def test_wrong_lengths(self):
values = [1.1, 2.2, 3.3]
expected = [
1.1**2,
]
self.verify_complex_case(values, expected)

def test_wrong_results(self):
values = [1.1, 2.2, 3.3]
expected = [
1.1**2,
1.1**2 + 2.2**2,
1.1**2 + 2.2**2 + 3.3**2 + 4.4**2,
]
self.verify_complex_case(values, expected)

if __name__ == '__main__':
main()

>>>
Failure
Traceback (most recent call last):
File "XXX.py", line 41, in test_wrong_lengths
self.verify_complex_case(values, expected)
File "XXX.py", line 34, in verify_complex_case
self.fail('Found longer than expected')
AssertionError: Found longer than expected



Ran 2 tests in 0.009s

FAILED (failures=2)

Index 2 is wrong
16.939999999999998 != 36.3

预期:36.3
实际:16.939999999999998
<点击以查看差异>

Traceback (most recent call last):
File "XXX.py", line 50, in test_wrong_results
self.verify_complex_case(values, expected)
File "XXX.py", line 16, in verify_complex_case
self.assertEqual(
AssertionError: 36.3 != 16.939999999999998 : Index 2 is wrong

可以把相关的测试归为一组,并针对每组测试定义相应的TestCase子类。如果某个函数有许多种边界情况要测,那么笔者喜欢专门针对这个函数定义一个TestCase子类,而对那些比较简单的函数,笔者则喜欢把同一个模块里的这些函数全都放在同一个TestCase子类中。另外,笔者喜欢针对每个基本的类单独创建对应的TestCase子类,以测试该类及类中的各个方法。TestCase类还提供了subTest辅助方法,可以让我们把相似的用例全都写在同一个test方法中,让它们成为这个用例中的子用例,这样的话,每个子用例所共用的那部分代码与逻辑只需要写一次就行。对由数据所驱动的测试来说,这个辅助方法尤其有用,因为其中一条数据(也就是一项子用例)测试失败,并不影响后面的数据(也就是后面的那些子用例)继续接受测试(这与TestCase子类里的那些test方法一样, 即便其中有某个test方法测试失败,其他的test方法也还是可以继续接受测试)。下面定义一套数据,演示如何通过subTest方法做数据驱动测试。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from unittest import TestCase, main
from utils import to_str

class DataDrivenTestCase(TestCase):
def test_good(self):
good_cases = [
(b'my bytes', 'my bytes'),
('no error', b'no error'), # This one will fail
('other str', 'other str'),
]
for value, expected in good_cases:
with self.subTest(value):
self.assertEqual(expected, to_str(value))

def test_bad(self):
bad_cases = [
(object(), TypeError),
(b'\xfa\xfa', UnicodeDecodeError),
]
for value, exception in bad_cases:
with self.subTest(value):
with self.assertRaises(exception):
to_str(value)

if __name__ == '__main__':
main()

>>>
SubTest failure: Traceback (most recent call last):
File "C:\...\unittest\case.py", line 57, in testPartExecutor
yield
File "C:\...\unittest\case.py", line 538, in subTest
yield
File "XXX.py", line 13, in test_good
self.assertEqual(expected, to_str(value))
File "C:\...\diff_tools.py", line 33, in _patched_equals
old(self, first, second, msg)
AssertionError: b'no error' != 'no error'



One or more subtests failed
Failed subtests list: [no error]


Ran 2 tests in 0.015s

FAILED (failures=1)

如果项目比较复杂,可以用pytest包。

第77条 把测试前、后的准备与清理逻辑写在setUp、tearDown、setUp-Module与tearDown-Module中,以防用例之间相互干扰

我们可以在TestCase子类中覆写setUp与tearDown方法,并把相应的准备逻辑与清理逻辑写在里面。系统在执行每个test方法之前都会先调用一遍setUp方法,并在执行完test方法之后调用一遍tearDown方法。这可以确保测试用例之间不会互相干扰,这一点,对测试工作至关重要。

例如,我们可以像下面这样把创建临时目录的逻辑放在setUp方法中,使系统在执行在执行test_modify_file用例之前先把存放 ‘data.bin’ 文件所用的临时目录准备好,并在执行完用例之后,通过tearDown方法清空该目录。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest import TestCase, main

class EnvironmentTest(TestCase):
def setUp(self):
self.test_dir = TemporaryDirectory()
self.test_path = Path(self.test_dir.name)

def tearDown(self):
self.test_dir.cleanup()

def test_modify_file(self):
with open(self.test_path / 'data.bin', 'w') as f:
f.write('hello')

if __name__ == '__main__':
main()

当程序变得复杂后,我们就不能只依赖这种彼此隔绝的单元测试了,而是需要再写一些测试,以验证模块与模块之间能否正确地交互(这可能要用到mock等工具,参见第78条)。这种测试叫集成测试(integration test),它跟前面的单元测试(unit test)不同。这两种测试在Python中很重要,假如不做集成测试,那就没办法确信这些模块能够协同运作。

对于集成测试来说,测试环境的准备与清理工作可能要占用大量计算资源,并持续比较长的时间。例如,可能要先启动数据库进程,并等待该进程把索引加载进来,然后才能开始做集成测试。这些工作的延迟很高,因此不能像做单元测试时那样,写在setUp与tearDown 方法中。

unittest模块支持模块级别的测试用具初始化,以解决集成测试的准备与清理问题。这样的话,那些高成本的资源只在setUpModule中初始化一次就好,而不用在每个test方法运行之前都重复初始化一遍。待所有的test方法执行完,会在tearDownModule函数里清理这项资源,当然也只需要清理一次就行。下面我们就在包含TestCase子类的这个模块里定义setUpModule与tearDownModule函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from unittest import TestCase, main

def setUpModule():
print('* Module setup')

def tearDownModule():
print('* Module clean-up')

class IntegrationTest(TestCase):
def setUp(self):
print('* Test setup')

def tearDown(self):
print('* Test clean-up')

def test_end_to_end1(self):
print('* Test 1')

def test_end_to_end2(self):
print('* Test 2')

if __name__ == '__main__':
main()

>>>
* Module setup
* Test setup
* Test 1
* Test clean-up
* Test setup
* Test 2
* Test clean-up
* Module clean-up


Ran 2 tests in 0.003s

OK

第78条 用Mock来模拟受测代码所依赖的复杂函数

写测试的时候还有一个常见的问题,就是某些逻辑很难从开发环境里真实的执行,或者使用起来特别慢,这样的逻辑可以通过mock函数与Mock类来模拟。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from datetime import datetime
from unittest.mock import Mock

class DatabaseConnection:
def __init__(self, host, port):
pass

class DatabaseConnectionError(Exception):
pass

def get_animals(database, species):
# Query the Database
raise DatabaseConnectionError('Not connected')

mock = Mock(spec=get_animals)
expected = [
('Spot', datetime(2019, 6, 5, 11, 15)),
('Fluffy', datetime(2019, 6, 5, 12, 30)),
('Jojo', datetime(2019, 6, 5, 12, 45)),
]
mock.return_value = expected
  • unittest.mock模块中的Mock类能够模拟某个接口的行为,我们可以用它替换受测函数所要调用的接口,因为那些接口可能不太容易在测试的过程中配置。

  • 如果用mock把受测代码所依赖的函数替换掉了,那么在测试的时候,不仅要验证受测代码的行为,而且还要验证它有没有正确地调用这些mock,这可以通过Mock.assert_called_once_with等一系列方法实现。

  • 要想把受测函数所调用的其他函数用mock逻辑替换掉,一种办法是给受测函数设计只能以关键字来指定的参数;另一种办法是通过unittest.mock.patch系列的方法暂时隐藏那些函数。

第79条 把受测代码所依赖的系统封装起来,以便于模拟和测试

上一条(也就是第78条)讲了怎样用Python内置的unittest.mock模块测试需要依赖复杂系统(例如数据库)才能运作的代码。我们当时讲了两套方案,一个是通过Mock类实现,另一个是通过patch方法实现。可是,这两种方案都要求我们在测试的过程中重复编写很多例行代码,这会让初次阅读代码的人很难理解我们究竟要验证什么。

有一种办法可以改进代码,就是把受测函数所要使用的数据库接口封装起来,这样我们就不用像原来那样,专门把数据库连接(DatabaseConnection)当作参数传给受测函数了,而是可以将封装好的系统传过去。这种代码重构通常是很值得采取的,因为这样可以形成更好的抽象层,让我们能够更方便地创建mock逻辑,并用这些仿制的逻辑来编写测试用例(还有一种重构,参见第89条)。下面重新定义受测函数所用到的三个辅助函数,但是这次我们将这些函数放在一个叫作ZooDatabase 的类中,让它们成为该类的方法,而不像原来那样作为独立的函数出现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class ZooDatabase:

def get_animals(self, species):
pass

def get_food_period(self, species):
pass

def feed_animal(self, name, when):
pass

from datetime import datetime

def do_rounds(database, species, *, utcnow=datetime.utcnow):
now = utcnow()
feeding_timedelta = database.get_food_period(species)
animals = database.get_animals(species)
fed = 0

for name, last_mealtime in animals:
if (now - last_mealtime) >= feeding_timedelta:
database.feed_animal(name, now)
fed += 1

return fed

from unittest.mock import Mock

database = Mock(spec=ZooDatabase)
database.feed_animal()
database.feed_animal.assert_any_call()

from datetime import timedelta
from unittest.mock import call

now_func = Mock(spec=datetime.utcnow)
now_func.return_value = datetime(2019, 6, 5, 15, 45)

database = Mock(spec=ZooDatabase)
database.get_food_period.return_value = timedelta(hours=3)
database.get_animals.return_value = [
('Spot', datetime(2019, 6, 5, 11, 15)),
('Fluffy', datetime(2019, 6, 5, 12, 30)),
('Jojo', datetime(2019, 6, 5, 12, 55))
]

result = do_rounds(database, 'Meerkat', utcnow=now_func)
assert result == 2

database.get_food_period.assert_called_once_with('Meerkat')
database.get_animals.assert_called_once_with('Meerkat')
database.feed_animal.assert_has_calls(
[
call('Spot', now_func.return_value),
call('Fluffy', now_func.return_value),
],
any_order=True)
  • 在写单元测试的时候,如果总是要反复使用许多代码来注入模拟的逻辑,那么可以考虑把受测函数所要用到的逻辑封装到类中,因为封装之后更容易注入。
  • Python内置的unittest.mock模块里有个 Mock类,它能模拟类的实例,这种Mock对象具备与原类中的方法相对应的属性。如果在它上面调用某个方法,就会触发相应的的属性。
  • 如果想把程序完整地测一遍,那么可以重构代码,在原来直接使用复杂系统的地方引入辅助函数,让程序通过这些函数来获取它要用的系统,这样我们就可以通过辅助函数注入模拟逻辑。

第80条 考虑用pdb做交互测试

在编写程序的过程中总是会遇到bug。有时,可以通过print函数打印相关的信息,以追查导致程序出错的原因(参见第75条);有时针对特定的情况编写测试用例也可以很明确地把程序所遇到的问题暴露出来(参见第
76条)。

但是,这些手段并不能发现所有的错误,有时我们得求助更强大的工具。Python内置的交互调试器(interactive debugger)就是这样一种工具,它可以检查程序状态打印局部变量的值,还可以每次只执行一条Python语句(也就是单步执行))。

在其他大部分编程语言中,如果要使用调试器,那么必须先在源文件中指定断点,令程序在执行到这一行时停下来。然而 Python不用这样,你可以直接在认为有问题的那行代码前加入一条指令,让程序暂停,并启动调试器,这是最简单的办法。采用这种办法来调试程序,与正常启动程序并没有什么区别。

用来触发调试器的指令是breakpoint函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import math

def compute_rmse(observed, ideal):
total_err_2 = 0
count = 0
for got, wanted in zip(observed, ideal):
err_2 = (got - wanted) ** 2
breakpoint() # Start the debugger here
total_err_2 += err_2
count += 1

mean_err = total_err_2 / count
rmse = math.sqrt(mean_err)
return rmse

result = compute_rmse(
[1.8, 1.7, 3.2, 6],
[2, 1.5, 3, 5])
print(result)

在(Pdb)提示符界面我们可以输入局部变量的名称(或执行p <name>命令)来查看变量的取值。也可以调用Python内置的locals函数以观察所有的局部变量,还可以引入模块,检查全局状态,构造新的对象,或运行内置的help命令,甚至还能修改正在运行的程序里的某些部分。总之,对调试工作有帮助的操作都可以在这里执行。

另外,调试器还提供了各种特殊命令,帮我们控制程序的执行方式,并探查其执行情况。在调试界面输入help,可以看到完整的命令列表。

通过下面这三条非常实用的命令,我们可以很方便地检查正在运行的这个程序:

  • where:打印出当前的执行调用栈(execution call stack),可以据此判断程序当前执行到了哪个位置,以及程序是在调用了哪些函数后才触发breakpoint断点的。
  • up:把观察点沿着执行调用栈上移一层,回到当前函数调用者处,以观察位于当前断点之上的那些层面分别有什么样的局部变量。
  • down:把观察点沿着执行调用栈下移一层。

检查完程序的运行状态后,可以通过下面这五条命令决定程序接下来应该如何执行:

  • step:执行程序里的下一行代码,并在执行完毕后把控制权交还给调试器。如果下一行代码带有函数调用操作,那么调试器就会停在受调用的那个函数开头。
  • next :执行当前函数的下一行代码,并在执行完毕后,返回交互调试界面。如果下一行代码带有函数调用操作,系统不会令调试器停在受调用的函数开头。
  • return:让程序一直运行到当前函数返回为止,然后把控制权交还给调试器。
  • continue :让程序运行到下一个断点处(那个断点可以是通过breakpoint触发的,也可以是在调试界面里设置的)。
  • quit:退出调试界面,并且让接受调试的程序也随之终止。如果已经找到了问题,那么就可以用这个命令结束调试。如果发现寻找的方向不对,或者需要先去修改程序的代码,那么也应该运行这个命令以便重新调试。

breakpoint函数可以出现在程序里的任何地方。

调试器还支持一项有用的功能,叫作事后调试(post-mortem debugging),当我们发现程序会抛出异常并崩溃后,想通过调试器看看它在抛出异常的那一刻,究竟是什么样子的。有时我们也不确定应该在哪里调用 breakpoint函数,在这种情况下,尤其需要这项功能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import math

def compute_rmse(observed, ideal):
total_err_2 = 0
count = 0
for got, wanted in zip(observed, ideal):
err_2 = (got - wanted) ** 2
total_err_2 += err_2
count += 1

mean_err = total_err_2 / count
rmse = math.sqrt(mean_err)
return rmse

result = compute_rmse(
[1.8, 1.7, 3.2, 7j], # Bad input
[2, 1.5, 3, 5])
print(result)
1
python3 -m pdb -c continue <program path>

第81条 用tracemalloc来掌握内存的使用与泄露情况

在Python的默认实现方式(也就是CPython)中,内存管理是通过引用计数(referencecounting)执行的。如果指向某个对象的引用已经全部过期,那么受引用的对象就可以从内存中清除,从而给其他数据腾出空间。另外,CPython还内置了循环检测器(cycle detector),确保那些自我引用的对象也能够得到清除。

从理论上讲,这意味着Python开发者不用担心程序如何分配并释放内存的问题,因为Python系统本身以及CPython运行时环境会自动处理这些问题。但实际上,还是会有程序因为没有及时释放不再需要引用的数据而耗尽内存。想了解Python程序使用内存的情况,或找到泄漏内存的原因,是比较困难的。

第一种调试内存使用状况的办法,是用Python内置的gc模块把垃圾回收器目前知道的每个对象都列出来。虽然这样有点儿笨,但毕竟可以让我们迅速得知程序的内存使用状况。

下面先定义这样一个准备接受测试的模块,让它生成一些对象并加以引用,从而令这些对象能够占据一定空间。然后运行,再打印。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os

class MyObject:
def __init__(self):
self.data = os.urandom(100)

def get_data():
values = []
for _ in range(100):
obj = MyObject()
values.append(obj)
return values

def run():
deep_values = []
for _ in range(100):
deep_values.append(get_data())
return deep_values

import gc

found_objects = gc.get_objects()
print('Before:', len(found_objects))

hold_reference = run()

found_objects = gc.get_objects()
print('After: ', len(found_objects))
for obj in found_objects[:3]:
print(repr(obj)[:100])

print('...')

>>>
Before: 7750
After: 17802
<__main__.MyObject object at 0x000001E57D1D7B90>
<__main__.MyObject object at 0x000001E57D1D7BD0>
<__main__.MyObject object at 0x000001E57D1D7C10>
...

gc.get_objects函数的缺点在于,它并没有指出这些对象究竟要如何分配。在比较复杂的程序中,同一个类的对象可能是因为好几种不同的原因而为系统所分配的。知道对象的总数固然有意义,但更为重要的是找到分配这些对象的具体代码,这样才能查清内存泄漏的原因。

Python 3.4版本推出了一个新的内置模块,名为tracemalloc,它可以解决刚才讲的那个问题。tracemalloc能够追溯对象到分配它的位置,因此我们可以在执行受测模块之前与执行完毕之后,分别给内存使用情况做快照,并对比两份快照,以了解它们之间的区别。

下面我们就用这个方法把受测程序中分配内存最多的那三处找出来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import tracemalloc

tracemalloc.start(10) # Set stack depth
time1 = tracemalloc.take_snapshot() # Before snapshot

x = run() # Usage to debug
time2 = tracemalloc.take_snapshot() # After snapshot

stats = time2.compare_to(time1, 'lineno') # Compare snapshots
for stat in stats[:3]:
print(stat)

>>>
C:\XXX.py:5: size=1299 KiB (+1299 KiB), count=10000 (+10000), average=133 B
C:\XXX.py:10: size=785 KiB (+785 KiB), count=20000 (+20000), average=40 B
C:\XXX.py:11: size=84.4 KiB (+84.4 KiB), count=100 (+100), average=864 B

tracemalloc还可以打印完整的栈追踪信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import tracemalloc

tracemalloc.start(10)
time1 = tracemalloc.take_snapshot()

x = run()
time2 = tracemalloc.take_snapshot()

stats = time2.compare_to(time1, 'traceback')
top = stats[0]
print('Biggest offender is:')
print('\n'.join(top.traceback.format()))

>>>
Biggest offender is:
File "C:\XXX.py", line 25
x = run()
File "C:\XXX.py", line 17
deep_values.append(get_data())
File "C:\XXX.py", line 10
obj = MyObject()
File "C:\XXX.py", line 5
self.data = os.urandom(100)

第十章 协作开发

第82条 学会寻找由其他Python开发者所构建的模块

Python有个集中存放模块的地方,叫Python Package Index(PyPI),网址http://pypi.org,可以从中安装模块。可以用pip命令行工具安装软件包。

第83条 用虚拟环境隔离项目,并重建依赖关系

pip会将新的安装包默认安装到全局路径之中,这会让每个涉及该模块的程序都受到影响。在间接的依赖关系中会出现问题。这个问题可以用虚拟环境解决:

默认安装的内置工具venv。

创建:

1
python3 -m venv <env_name>

启用/禁用:

1
2
source bin/activate
source bin/deactivate

把当前环境所依赖的包保存到文件中,再重新安装:

1
2
python3 -m pip freeze > requirements.txt
python3 -m pip install -r requirements.txt

Anaconda

网址:https://www.anaconda.com

笔者更推荐使用Anaconda

第84条 每一个函数、类与模块都要写docstring

Python允许我们在程序运行过程中,直接访问这些文档:

1
2
3
4
5
6
7
8
def palindrome(word):
"""Return True if the given word is a palindrome."""
return word == word[::-1]

print(repr(palindrome.__doc__))

>>>
'Return True if the given word is a palindrome.'

编写文档的优点:

  • 开发者能够在程序中访问文档信息,这会让交互式开发工作变得更加轻松。可以用内置的help函数查看与某个函数、类及模块相对应的文档。无论是基本的Python解释器界面(也就是默认的Python shell),还是IPython Notebook这样的高级工具,都可以相当方便地查询文档,这让我们能够愉快地研究算法、测试API并编写代码片段。
  • 这些文档是按照标准的方式定义的,因此很容易就能转换成表现力更强的格式(例如HTML)。这也促使Python开发者推出Sphinx等优秀的文档生成工具,另外还有像Read the Docs这样由开发者社群所赞助的网站可以为开源的Python项目免费存放美观的文档。
  • Python文档不仅可以做得很漂亮,而且与其他普通的头等Python实体一样,也能够在程序里面正常地访问,这会让开发者更乐意编写这样的文档。许多Python开发者都坚信,文档是很重要的。有人认为,如果一段代码能称得上好代码,那么其中的文档肯定也写得不错。所以,很多优秀的开源Python项目里面,应该都有比较好的文档。

为模块编写文档

每个模块都要有顶级的docstring,即写在源文件开头的那个字符串。字符串的首尾都要带三重引号,这样的字符串的目的主要是介绍本模块与其中的内容。

在 docstring里面,第一行应是一个单句,描述本模块的用途。接下来应该另起一段,详解讲述使用这个模块的用户所要知道的一些事项。另外,凡是模块里面比较重要的类与函数,都应该在docstring 中予以强调,这样的话,查看这份文档的用户就可以从这些类及函数出发来熟悉模块。

下面举个例子,讲解如何为模块编写docstring。

1
2
3
4
5
6
7
8
9
10
11
"""Library for finding linguistic patterns in words.

Testing how words relate to each other can be tricky sometimes!
This module provides easy ways to determine when words you've
found have special properties.

Available functions:
- palindrome: Determine if a word is a palindrome.
- check_anagram: Determine if two words are anagrams.
...
"""

为类编写文档

每个类都应该有类级别的 docstring,这种文档的写法,与模块级别的docstring差不多。它的第一段,也需要用一句话来概述整个类的用途。后面的各段,可以详细讲解本类中的每一种操作。

类中比较重要的public属性与方法,同样应该在类级别的 docstring里面加以强调。另外还需要说明,如果想编写子类,子类应该怎样与受保护的属性(参见第42条)以及超类中的方法相交互。

下面演示类的 docstring应该如何编写。

1
2
3
4
5
6
7
8
9
10
11
class Player:
"""Represents a player of the game.

Subclasses may override the 'tick' method to provide
custom animations for the player's movement depending
on their power level, etc.

Public attributes:
- power: Unused power-ups (float between 0 and 1).
- coins: Coins found during the level (integer).
"""

为函数编写文档

每个 public函数与方法都应该有docstring。它的写法与模块和类的相同,第一段也是一个句子,描述这个函数是做什么的。接下来的那段应该描述函数的行为。然后,可以各用一段来描述函数的参数与返回值。另外,如果调用者在使用这个函数接口的时候,需要处理该函数所抛出的一些异常,那么这些异常也要解释(参见第20条)。

下面举例说明如何为函数编写 docstring。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import itertools

def find_anagrams(word, dictionary):
"""Find all anagrams for a word.

This function only runs as fast as the test for
membership in the 'dictionary' container.

Args:
word: String of the target word.
dictionary: collections.abc.Container with all
strings that are known to be actual words.

Returns:
List of anagrams that were found. Empty if
none were found.
"""
permutations = itertools.permutations(word, len(word))
possible = (''.join(x) for x in permutations)
found = {word for word in possible if word in dictionary}
return list(found)

写docstring的时候,需要注意下面几种特殊的情况:

  • 如果函数没有参数,而且返回的是个比较简单的值,那么就不用按照上面讲的那种格式分段书写了。直接用一句话来描述整个函数可能会更好。
  • 如果函数没有返回值,那么最好是把描述返回值的那段完全省去,而不要专门写出返回None。
  • 如果函数所抛出的异常也是接口的一部分(参见第20条),那么应该在docstring里面详细解释每一种异常的含义,并说明函数在什么场合会抛出这样的异常。
  • 如果函数在正常使用的过程中,不会抛出异常,那么无须专门指出这一点。如果函数可以接受数量可变的位置参数(参见第22条或关键字参数(参见第23条),那么应该在解释参数的那一部分用*args与**kwargs来说明这两种参数的用途。
  • 如果参数有默认值,那么文档里应该提到这些默认值(参见第24条)。
  • 如果函数是个生成器(参见第30条),那么应该在docstring里面写明这个生成器在迭代过程中会产生什么样的值。
  • 如果函数是异步协程(参见第60条),那么应该在docstring里面解释这个协程执行到何时会暂停。

用类型注解来简化docstring

1
2
3
4
5
from typing import Container, List

def find_anagrams(word: str,
dictionary: Container[str]) -> List[str]:
pass

类型信息最好只写在类型注解或docstring其中一个地方,防止修改时忘了改其中一个。

第85条 用包来安排模块,以提供稳固的API

大多数情况下,把名为__init__.py的空白文件放在某个目录中,即可令该目录成为一个包。

用包划分名称空间

包的一个用途是帮助把模块安排到不同的名称空间里,这样即使两个文件同名,也能够区分。

引入的函数或模块同名的话,应该加上as子句:

1
2
from analysis.utils import inspect as analysis_inspect
from frontend.utils import inspect as frontend_inspect

或者:

1
2
3
4
5
6
7
import analysis.utils
import frontend.utils

value = 33
if (analysis.utils.inspect(value) ==
frontend.utils.inspect(value)):
print('Inspection equal!')

通过包来构建稳固的API

要想API的功能稳定,必须隐藏软件包内部的代码结构,不要让外部的开发者依赖这套结构。

Python允许我们通过__all__这个特殊的属性,决定模块或包里面有哪些内容应该当做API公布到外界。

1
2
3
4
5
6
7
# model.py
__all__ = ['Projectile']

class Projectile:
def __init__(self, mass, velocity):
self.mass = mass
self.velocity = velocity
1
2
3
4
5
6
7
8
9
10
11
12
# utils.py
from . models import Projectile

__all__ = ['simulate_collision']

def _dot_product(a, b):
pass

def simulate_collision(a, b):
after_a = Projectile(-a.mass, -a.velocity)
after_b = Projectile(-b.mass, -b.velocity)
return after_a, after_b
1
2
3
4
5
6
# __init__.py
__all__ = []
from . models import *
__all__ += models.__all__
from . utils import *
__all__ += utils.__all__

尽量不要用import * 的形式!!

第86条 考虑用模块级别的代码配置不同的部署环境

生产环境/开发环境切换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# dev_main.py
TESTING = True

import db_connection

db = db_connection.Database()

# prod_main.py
TESTING = True

import db_connection

db = db_connection.Database()

# db_connection.py
import __main__

class TestingDatabase:
pass

class RealDatabase:
pass

if __main__.TESTING:
Database = TestingDatabase
else:
Database = RealDatabase

第87条 为自编的模块定义根异常,让调用者能够专门处理与此API有关的异常

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class Error(Exception):
"""Base-class for all exceptions raised by this module."""

class InvalidDensityError(Error):
"""There was a problem with a provided density value."""

class InvalidVolumeError(Error):
"""There was a problem with the provided weight value."""

def determine_weight(volume, density):
if density < 0:
raise InvalidDensityError('Density must be positive')
if volume < 0:
raise InvalidVolumeError('Volume must be positive')
if volume == 0:
density / volume

有了这样的根异常,调用这个API的开发者就可以通过它把所有相关的错误全部捕获下来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class my_module:
Error = Error
InvalidDensityError = InvalidDensityError

@staticmethod
def determine_weight(volume, density):
if density < 0:
raise InvalidDensityError('Density must be positive')
if volume < 0:
raise InvalidVolumeError('Volume must be positive')
if volume == 0:
density / volume

try:
weight = my_module.determine_weight(1, -1)
except my_module.Error:
logging.exception('Unexpected error')
else:
assert False

定义根异常有三大好处:

  • 让调用者能够注意自己在使用API时出现的疏忽。

  • 有助于发现API本身的bug。写API时只抛出继承自根异常的错误。如果发生了其他错误,说明可能API实现里有bug。

  • 让开发者以后能够平稳地更新API。

这种思路还可以继续推广:在根异常下创建几个小的门类,让每个门类都有自己的根异常。

第88条 用适当的方式打破循环依赖关系

与其他人合作的时候,难免会遇上两个模块相互依赖的情况:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# dialog.py
import app

class Dialog:
def __init__(self, save_dir):
self.save_dir = save_dir

save_dialog = Dialog(app.prefs.get('save_dir'))

def show():
print('Showing the dialog!')

# app.py
import dialog

class Prefs:
def get(self, name):
pass

prefs = Prefs()
dialog.show()

# main.py
import app

最好的解法是重构代码,把prefs数据结构放到依赖体系最底层。

除了这种解法以外,还有三个办法也可以解除循环依赖关系:

调整import语句的位置

1
2
3
4
5
6
7
8
9
# app.py
class Prefs:
def get(self, name):
pass

prefs = Prefs()

import dialog # Moved
dialog.show()

但这违背了PEP8,并且无法保证这份源文件里所有代码都能用这个导入模块。因此不推荐这种办法。

把模块分成引入-配置-运行三个环节

尽量缩减引入时所要执行的操作,让模块只把函数、类与常量定义出来,而不真的去执行操作,这样的话,Python程序在引入本模块的时候,就不会由于操作其他模块而出错了。我们可以把本模块里面,需要用到其他模块的那种操作放在configure函数中,等到本模块彻底引入完毕后再去调用。configure函数会访问其他模块中的相关属性,以便将本模块的状态配置好。这个函数是在该模块与它所要使用的那个模块都已经彻底引人后才调用的(也就是说,这两个模块都把各自的第5步执行完了),因此,其中涉及的所有属性全都定义过了。

下面,我们就按照这个思路改写dialog模块,让它不要刚一上来就访问prefs对象,而是待configure函数被调用时,再去访问。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# dialog.py
import app

class Dialog:
def __init__(self):
pass

save_dialog = Dialog()

def show():
print('Showing the dialog!')

def configure():
save_dialog.save_dir = app.prefs.get('save_dir')

# app.py
import dialog

class Prefs:
def get(self, name):
pass

prefs = Prefs()

def configure():
pass

# main.py
import app
import dialog

app.configure()
dialog.configure()

dialog.show()

动态引入

第三个办法比前两个都简单,也就是把import语句从模块级别下移到函数或方法里面,这样就可以解除循环依赖关系了。这种import语句并不会在程序启动并初始化本模块时执行,而是等到相关函数真正运行的时候才得以触发,因此又叫作动态引入(dynamic import)。

下面,我们就用动态引人的办法修改dialog模块。这次,它只会在dialog.show函数真正运行的时候去引入import模块,而不像原来那样,模块刚一初始化,就要引入app。

1
2
3
4
5
6
7
8
9
10
11
12
13
# dialog.py
class Dialog:
def __init__(self):
pass

# Using this instead will break things
# save_dialog = Dialog(app.prefs.get('save_dir'))
save_dialog = Dialog()

def show():
import app # Dynamic import
save_dialog.save_dir = app.prefs.get('save_dir')
print('Showing the dialog!')

这样写,实际上与刚才那种先引入、再配置,然后运行的办法,是类似的。区别仅仅在于,这次不调整代码的结构,也不修改模块的定义与引人方式,只是把形成循环依赖的那条import语句,推迟到真正需要使用另外一个模块的那一刻。那时,自然可以确信本模块所依赖的那个模块肯定已经初始化过了(也就是说,那个模块的第5步肯
定已经执行完了)。

当然了,一般来说,还是应该尽量避免动态引入,因为import语句毕竟是有开销的,如果它出现在需要频繁执行的循环体里面,那么这种开销会更大。另外,由于动态引入会推迟代码的执行时机,有可能让你的程序在启动了很久之后,突然因为在动态引入其他模块的过程中发生SyntaxError等错误而崩溃(如何避免此类问题,请参见第76
条)。

第89条 重构时考虑通过warnings提醒开发者API已经发生变化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import warnings

CONVERSIONS = {
'mph': 1.60934 / 3600 * 1000, # m/s
'hours': 3600, # seconds
'miles': 1.60934 * 1000, # m
'meters': 1, # m
'm/s': 1, # m
'seconds': 1, # s
}

def convert(value, units):
rate = CONVERSIONS[units]
return rate * value

def localize(value, units):
rate = CONVERSIONS[units]
return value / rate

def print_distance(speed, duration, *,
speed_units=None,
time_units=None,
distance_units=None):
if speed_units is None:
warnings.warn(
'speed_units required', DeprecationWarning)
speed_units = 'mph'

if time_units is None:
warnings.warn(
'time_units required', DeprecationWarning)
time_units = 'hours'

if distance_units is None:
warnings.warn(
'distance_units required', DeprecationWarning)
distance_units = 'miles'

norm_speed = convert(speed, speed_units)
norm_duration = convert(duration, time_units)
norm_distance = norm_speed * norm_duration
distance = localize(norm_distance, distance_units)
print(f'{distance} {distance_units}')

import contextlib
import io

fake_stderr = io.StringIO()
with contextlib.redirect_stderr(fake_stderr):
print_distance(1000, 3,
speed_units='meters',
time_units='seconds')

print(fake_stderr.getvalue())

>>>
1.8641182099494205 miles
C:\XXX.py:35: DeprecationWarning: distance_units required
warnings.warn(

warnings.warn函数提供了一个名为stacklevel的参数,让我们可以根据栈的深度指出真正触发这条警告的那个位置,而不是调用warnings.warn函数的字面位置。这项功能让我们可以把发出警告的这段逻辑封装成辅助函数,并通过这个辅助函数检查用户在调用print_distance时,有没有指定相关的参数,如果没有,就打印出调用print_distance的那行语句所在的位置。早前用来检查参数取值的那几个if结构,现在全都可以改由这样的辅助函数来实现。下面就定义辅助函数,如果用户没有明确给print_distance的某个参数传值,那么require函数会发出警告并且让该参数取默认值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def require(name, value, default):
if value is not None:
return value
warnings.warn(
f'{name} will be required soon, update your code',
DeprecationWarning,
stacklevel=3)
return default

def print_distance(speed, duration, *,
speed_units=None,
time_units=None,
distance_units=None):
speed_units = require('speed_units', speed_units, 'mph')
time_units = require('time_units', time_units, 'hours')
distance_units = require(
'distance_units', distance_units, 'miles')

norm_speed = convert(speed, speed_units)
norm_duration = convert(duration, time_units)
norm_distance = norm_speed * norm_duration
distance = localize(norm_distance, distance_units)
print(f'{distance} {distance_units}')
  • 设计新版API的时候,可以通过warnings模块把已经过时的用法通知到调用者,让他们看到消息后尽快改用新的写法,以防程序在我们彻底放弃旧版API之后崩溃。
  • 在命令行界面执行Python解释器的时候,可以开启 -W error选项,从而将警告视为错误。这在执行自动测试的过程中特别有用,因为这样可以及时发现受测程序所依赖的 API是否已经推出了新的版本。
  • 如果程序要部署到生产环境,那么可以通过logging模块将警告信息重定向到日志系统,把程序在运行过程中遇到的警告纳入现有的错误报告机制之中。
  • 如果你设计的 API会发出警告,那么应该为此编写测试,确保下游开发者在使用API的过程中,能够在适当的时机收到正确的警告信息。

第90条 考虑通过typing做静态分析,以消除 bug

文档可以很好地帮助用户理解API的正确用法(参见第84条),然而只有文档可能还不够,有时我们还是会把API用错,导致程序出现bug。所以,最好能有一套机制来验证用者使用API的方式是否正确,如果我们把自己的API发布出去,那么这套机制还能帮助其他开发者检查他们的代码有没有恰当地使用这套API。许多编程语言通过编译期的类检查来实现这种验证,这确实能够消除某些bug。

Python以前主要关注的是动态特性,所以没有提供编译期的类型安全机制。但是最近,Python开始引入一套特殊的写法,让我们可以通过内置的typing模块给变量、类中的字段、函数及方法添加类型信息。这些类型提示(type hint)信息可以实现渐进的类型判定机制(gradual typing),让我们在开发项目的过程中,把能够在编译期明确指定类型的地方逐渐确定下来。

给Python程序的代码添加类型信息之后,我们就可以运行静态分析(static analysis)工具,分析这些代码里面是否存在极有可能出现bug的地方。Python内置的typing模块身并不实现类型检查功能,它只是一套可以公开使用的代码库,其中定义了相关的类型(也包括泛型类型),我们可以用这些类型来注解Python代码,并利用其他工具根据这些类型判断受注解的代码有没有正确地得到使用。

Python解释器有许多种不同的实现方案,例如CPython、PyPy等,与之类似,typing模块相搭配的Python静态分析工具,也有很多方案。笔者编写本书的时候,比较流行的是mypy、pytype、pyright 与pyre。本书中typing范例,笔者打算用mypy来验证,而且验证时会带上–strict标志,以便将该工具所能判断的各种问题全都显示出来。下面这行命令,可以用mypy给example.py文件里的代码做静态分析。

1
python3 -m mypy --strict example.py

这些工具能够帮我们在运行程序之前,发现许多种常见的错误,除了把测试用例写好外(参见第76条),这样的工具会给代码多添加一层安全保障。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class Counter:
def __init__(self) -> None:
self.value: int = 0 # Field / variable annotation

def add(self, offset: int) -> None:
value += offset # Oops: forgot "self."

def get(self) -> int:
self.value # Oops: forgot "return"

counter = Counter()
counter.add(5)
counter.add(3)
assert counter.get() == 8

>>>
python3 -m mypy --strict example.py
.../example.py: error: Name 'value' is not defined
.../example.py: error: Missing return statement

可以利用typing模块给函数所涉及的泛型做注解,从而通过静态手段把程序运行时可能发生的错误提前探查出来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from typing import Callable, List, TypeVar

Value = TypeVar('Value')
Func = Callable[[Value, Value], Value]

def combine(func: Func[Value], values: List[Value]) -> Value:
assert len(values) > 0

result = values[0]
for next_value in values[1:]:
result = func(result, next_value)

return result

Real = TypeVar('Real', int, float)

def add(x: Real, y: Real) -> Real:
return x + y

inputs = [1, 2, 3, 4j] # Oops: included a complex number
result = combine(add, inputs)
assert result == 10