什么是Tail Call?
Tail Call(尾调用)指,我们调用一个函数后,立刻返回。
def triple(x): # 返回的不是函数
return x * 3
def is_tail_call(x): # 是尾调用,因为我们调用后就返回了
return triple(x)
def not_tail_call(x): # 不是尾调用,因为我们调用完还要1 + result
return 1 + triple(x)
def half_tail_call(x): # 外面的是尾调用,里面的不是
return triple(triple(x))
在以上例子中,is_tail_call的triple,跟half_tail_call的外面的triple,是tail call。
为什么Tail Call很重要?
因为Tail Call提供了很好的优化。
假设我们是register machine,return address存在一个register里面,如果我们tail call,我们需要:把argument放对位置,然后goto function就可以了!
如果是non tail call,我们需要把argument放对位置,保存旧的return address(比如说塞上stack),设定return address为下一行,goto function,把旧的return address restore掉,然后该干啥干啥。
这样,我们额外的多做了一push一pop,space & time overhead一下子上来了,怪不得60~70年代的人不肯用函数。。
当我们递归的时候,不做这个优化导致了big O上的空间差距 - 一个要不停的push stack,一个说,stack是啥?
另一点是,当递归的时候,argument的位置是对的,‘放对位置’这一步就没有了,于是就成为了一个goto self - 也就是loop。
尾递归优化,其实就是Tail Call Optimization在taill call self下的优化。(见Debunking the ‘Expensive Procedure Call’ Myth, or, Procedure Call Implementations Considered Harmful, or, Lambda: The Ultimate GOTO)
如果你会continuation,Tail Call说的就是,当你call进去的continuation,跟你自己的cont是eta equivalent时,直接传就好了。
Continuation某种程度下对应返回地址,所以这也是说,如果没有必要,不需要动返回地址。
但是:
很不幸的,有些解释器/编译器不做tail call optimization。这代表上面写递归函数,会占用大量空间。怎么办?
假设我们是C语言。
那很简单,直接inline asm啊!啥地方要tail call,直接inline一个goto过去,一切解决。
当然,inline asm没多少语言支持 - 连C也要开编译器扩展才行。
那,我们在程序里面模拟一个program counter,不就一切OK了吗?
# 转换:
# 在函数头插入global pc
# 所有尾调用变成pc = ...
# 改用eval
pc = None # program counter
def eval(f):
global pc
pc = f
while True:
pc()
cnt = 0
def naive_rec():
global cnt
print(cnt)
cnt += 1
naive_rec()
# naive_rec() # 错误:达到最深递归层次
cnt = 0
def rec():
global pc
global cnt
print(cnt)
cnt += 1
pc = rec
# eval(rec) # 远远超过了递归限制!
好像可以呢!
但是函数参数是不是要像register那样手动设置?那样实在太丑了。
幸好,python有closure,我们放进closure里面就行了。
# 转换:
# 尾调用变lambda
# eval加上lambda
pc = None # program counter
def eval(f):
global pc
pc = f
while True:
pc()
def naive_rec_arg(cnt):
print(cnt)
naive_rec_arg(cnt + 1)
# naive_rec_arg(0) # 毫无意外,同一错误
def rec_arg(cnt):
global pc
print(cnt)
pc = lambda: rec_arg(cnt + 1)
# eval(lambda: rec_arg(0)) # 再一次超越极限!
一切都好。但是,函数有输入,当然也有输出,现在怎么办?
要注意的是,我们不能用全局变量来模拟。
原因很简单,eval是个死循环,永远不返回,就算返回值assign到某全局变量,我们也不能让当前函数调用完,让调用者使用返回结果。
怎么办?
既然核心问题是eval,我们使得eval不是死循环就可以了啊!
规定:如果pc为None,则表明返回。
# 转换:
# 对尾调用,pc = lambda: ... 跟上 return None
# 对普通返回,pc = None 跟上 return ...
# 函数尾:pc = None
pc = None # program counter
def eval(f):
global pc
old_pc = pc # 如果eval里面用eval怎么办?保存老pc就OK
pc = f
result = None
while pc:
result = pc()
pc = old_pc
return result
def naive_is_even_0(x):
if x > 1:
return naive_is_even_1(x - 2)
if x == 1: # 故意这样写,测试early return正确性
return False
else:
return True # 通过互递归提高难度
def naive_is_even_1(x):
return naive_is_even_0(x)
assert not naive_is_even_0(123)
# assert not naive_is_even_0(2345) # 又是你
def is_even_0(x):
global pc
if x > 1:
pc = lambda: is_even_1(x - 2)
return None # 没有这行会接着运行!
if x == 1:
pc = None
return False
else:
pc = None
return True
pc = None # 这很二,但是我们希望转换越简单越好!
def is_even_1(x):
global pc
pc = lambda: is_even_0(x)
return None
pc = None
assert not eval(lambda: is_even_0(2345)) # 过!
很好。但是我们看一下,这代码很危险!
很简单,我们要先set pc,然后再return一些东西。
但是,如果我们手滑了一下,忘了set pc,会怎么样?
这样,我们就会进入死循环,而得益于我们的优化,我们甚至不能stack overflow来表示无限递归了!
我们也希望,pc set成lambda的时候,返回值是None:那时,返回值是被无视的,但是限制返回为None能更早暴露程序的问题。
我们希望,有一个python语句,SetFunAndReturnNone,来保证set pc后才能return。
同理,如果有个python语句,SetNoneAndReturnVal,保证set None后会返回一个值(没有返回的就返回None,表示没忘)也好,这样能确定程序员在用pc修改器的时候,不会忘记设返回值。
python没有这两个语句,但是还好,这两个表达式都是会立刻返回的。
既然eval已经在操控pc了,我们可以让他担任更多解释器的工作:我们可以定义代表这两语句的class,pc的返回值只能是这两语句(如果忘记设定,导致pc返回None,算作错误,因为这情况下pc也会忘记更新了)
# 转换:
# 去掉global pc
# pc = lambda:... 接 return None 改写成 return SetFunAndReturnNone(...)
# pc = None 接 return ... 改写成 return SetNoneAndReturnVal(...)
# 行尾插入SetNoneAndReturnVal(None)
class SetFunAndReturnNone:
def __init__(self, f):
self.f = f
class SetNoneAndReturnVal:
def __init__(self, x):
self.x = x
pc = None # program counter
def eval(f):
global pc
old_pc = pc # 如果eval里面用eval怎么办?保存老pc就OK
pc = f
result = None
while pc:
command = pc()
if isinstance(command, SetFunAndReturnNone):
pc = command.f
result = None
else:
assert isinstance(command, SetNoneAndReturnVal)
pc = None
result = command.x
pc = old_pc
return result
def is_even_0(x):
if x > 1:
return SetFunAndReturnNone(lambda: is_even_1(x - 2))
if x == 1:
return SetNoneAndReturnVal(False)
else:
return SetNoneAndReturnVal(True)
return SetNoneAndReturnVal(None)
def is_even_1(x):
return SetFunAndReturnNone(lambda: is_even_0(x))
return SetNoneAndReturnVal(None)
assert not eval(lambda: is_even_0(2345)) # 过!
我们现在架构的改动已经做完了。但是,还能优化下这个文件。
最基本的,pc只有eval在用了,能搬进去,也不用担心重入的问题了,因为每个eval的调用都会有自己的pc
def eval(f):
pc = f # program counter
result = None
while pc:
command = pc()
if isinstance(command, SetFunAndReturnNone):
pc = command.f
result = None
else:
assert isinstance(command, SetNoneAndReturnVal)
pc = None
result = command.x
return result
pc = None的时候能直接返回,不需要在while里面:
def eval(f):
pc = f # program counter
result = None
while pc:
command = pc()
if isinstance(command, SetFunAndReturnNone):
pc = command.f
result = None
else:
assert isinstance(command, SetNoneAndReturnVal)
return command.x
return result
result可以去掉了,因为只会assign成None
def eval(f):
pc = f # program counter
while pc:
command = pc()
if isinstance(command, SetFunAndReturnNone):
pc = command.f
else:
assert isinstance(command, SetNoneAndReturnVal)
return command.x
如果我们假设SetFunAndReturnNone.f不会是None,f也不是None,pc在循环中不会变成None。我们就能直接while True:
我们也假设f不会是None - 我们的所有代码都不会传个None进去
def eval(f):
pc = f # program counter
while True:
command = pc()
if isinstance(command, SetFunAndReturnNone):
pc = command.f
else:
assert isinstance(command, SetNoneAndReturnVal)
return command.x
这两个class名字好长。。。而且pc已经是内部实现了,我们不应该暴露出来。
我们重命名SetFunAndReturnNone做More(需要更多计算),同样的,重命名SetNoneAndReturnVal为Done(搞定)。
我们最后,再优化一下使用代码,去掉不需要的return。
class More:
def __init__(self, f):
self.f = f
class Done:
def __init__(self, x):
self.x = x
def eval(f):
pc = f # program counter
while True:
command = pc()
if isinstance(command, More):
pc = command.f
else:
assert isinstance(command, Done)
return command.x
def is_even_0(x):
if x > 1:
return More(lambda: is_even_1(x - 2))
if x == 1:
return Done(False)
else:
return Done(True)
def is_even_1(x):
return More(lambda: is_even_0(x))
assert not eval(lambda: is_even_0(2345)) # pass!
大功告成。
我们最后,跟小杰站着世界树之巅,找爸爸之余看看我们实现了什么:
当我们eval的时候,我们会先调用pc。
然后,我们会push一个stack frame。
当tail call出现的时候,与其去apply之(并且把这个stack frame也push进去),我们直接返回!
这样,我们当前的stack frame就会pop掉。
pop掉后,我们又立刻进入该 call,再push一个frame。
我们的frame数量,就会这样0-1-0-1-0-1,循环往复。故此,这个方法叫做trampoline。
如果你觉得trampoline很有趣,我推荐去看http://blog.higher-order.com/assets/trampolines.pdf,简单又漂亮的一篇paper。