什么是Tail Call?

Tail Call(尾调用)指,我们调用一个函数后,立刻返回。

def triple(x): # 返回的不是函数
    return x * 3

def is_tail_call(x): # 是尾调用,因为我们调用后就返回了
    return triple(x)

def not_tail_call(x): # 不是尾调用,因为我们调用完还要1 + result
    return 1 + triple(x)

def half_tail_call(x): # 外面的是尾调用,里面的不是
    return triple(triple(x))

在以上例子中,is_tail_call的triple,跟half_tail_call的外面的triple,是tail call。

为什么Tail Call很重要?

因为Tail Call提供了很好的优化。

假设我们是register machine,return address存在一个register里面,如果我们tail call,我们需要:把argument放对位置,然后goto function就可以了!

如果是non tail call,我们需要把argument放对位置,保存旧的return address(比如说塞上stack),设定return address为下一行,goto function,把旧的return address restore掉,然后该干啥干啥。

这样,我们额外的多做了一push一pop,space & time overhead一下子上来了,怪不得60~70年代的人不肯用函数。。

当我们递归的时候,不做这个优化导致了big O上的空间差距 - 一个要不停的push stack,一个说,stack是啥?

另一点是,当递归的时候,argument的位置是对的,‘放对位置’这一步就没有了,于是就成为了一个goto self - 也就是loop。

尾递归优化,其实就是Tail Call Optimization在taill call self下的优化。(见Debunking the ‘Expensive Procedure Call’ Myth, or, Procedure Call Implementations Considered Harmful, or, Lambda: The Ultimate GOTO

如果你会continuation,Tail Call说的就是,当你call进去的continuation,跟你自己的cont是eta equivalent时,直接传就好了。

Continuation某种程度下对应返回地址,所以这也是说,如果没有必要,不需要动返回地址。

但是:

很不幸的,有些解释器/编译器不做tail call optimization。这代表上面写递归函数,会占用大量空间。怎么办?

假设我们是C语言。

那很简单,直接inline asm啊!啥地方要tail call,直接inline一个goto过去,一切解决。

当然,inline asm没多少语言支持 - 连C也要开编译器扩展才行。

那,我们在程序里面模拟一个program counter,不就一切OK了吗?

# 转换:
# 在函数头插入global pc
# 所有尾调用变成pc = ...
# 改用eval

pc = None # program counter

def eval(f):
    global pc
    pc = f
    while True:
        pc()

cnt = 0
def naive_rec():
    global cnt
    print(cnt)
    cnt += 1
    naive_rec()
    
# naive_rec() # 错误:达到最深递归层次

cnt = 0
def rec():
    global pc
    global cnt
    print(cnt)
    cnt += 1
    pc = rec

# eval(rec) # 远远超过了递归限制!

好像可以呢!

但是函数参数是不是要像register那样手动设置?那样实在太丑了。

幸好,python有closure,我们放进closure里面就行了。

# 转换:
# 尾调用变lambda
# eval加上lambda

pc = None # program counter

def eval(f):
    global pc
    pc = f
    while True:
        pc()

def naive_rec_arg(cnt):
    print(cnt)
    naive_rec_arg(cnt + 1)

# naive_rec_arg(0) # 毫无意外,同一错误

def rec_arg(cnt):
    global pc
    print(cnt)
    pc = lambda: rec_arg(cnt + 1)

# eval(lambda: rec_arg(0)) # 再一次超越极限!

一切都好。但是,函数有输入,当然也有输出,现在怎么办?

要注意的是,我们不能用全局变量来模拟。

原因很简单,eval是个死循环,永远不返回,就算返回值assign到某全局变量,我们也不能让当前函数调用完,让调用者使用返回结果。

怎么办?

既然核心问题是eval,我们使得eval不是死循环就可以了啊!

规定:如果pc为None,则表明返回。

# 转换:
# 对尾调用,pc = lambda: ... 跟上 return None
# 对普通返回,pc = None 跟上 return ...
# 函数尾:pc = None

pc = None # program counter

def eval(f):
    global pc
    old_pc = pc # 如果eval里面用eval怎么办?保存老pc就OK
    pc = f
    result = None
    while pc:
        result = pc()
    pc = old_pc
    return result

def naive_is_even_0(x):
    if x > 1:
        return naive_is_even_1(x - 2)
    if x == 1: # 故意这样写,测试early return正确性
        return False
    else:
        return True # 通过互递归提高难度

def naive_is_even_1(x):
    return naive_is_even_0(x)

assert not naive_is_even_0(123)
# assert not naive_is_even_0(2345) # 又是你

def is_even_0(x):
    global pc
    if x > 1:
        pc = lambda: is_even_1(x - 2)
        return None # 没有这行会接着运行!
    if x == 1:
        pc = None
        return False
    else:
        pc = None
        return True
    pc = None # 这很二,但是我们希望转换越简单越好!

def is_even_1(x):
    global pc
    pc = lambda: is_even_0(x)
    return None
    pc = None

assert not eval(lambda: is_even_0(2345)) # 过!

很好。但是我们看一下,这代码很危险!

很简单,我们要先set pc,然后再return一些东西。

但是,如果我们手滑了一下,忘了set pc,会怎么样?

这样,我们就会进入死循环,而得益于我们的优化,我们甚至不能stack overflow来表示无限递归了!

我们也希望,pc set成lambda的时候,返回值是None:那时,返回值是被无视的,但是限制返回为None能更早暴露程序的问题。

我们希望,有一个python语句,SetFunAndReturnNone,来保证set pc后才能return。

同理,如果有个python语句,SetNoneAndReturnVal,保证set None后会返回一个值(没有返回的就返回None,表示没忘)也好,这样能确定程序员在用pc修改器的时候,不会忘记设返回值。

python没有这两个语句,但是还好,这两个表达式都是会立刻返回的。

既然eval已经在操控pc了,我们可以让他担任更多解释器的工作:我们可以定义代表这两语句的class,pc的返回值只能是这两语句(如果忘记设定,导致pc返回None,算作错误,因为这情况下pc也会忘记更新了)

# 转换:
# 去掉global pc
# pc = lambda:... 接 return None 改写成 return SetFunAndReturnNone(...)
# pc = None 接 return ... 改写成 return SetNoneAndReturnVal(...)
# 行尾插入SetNoneAndReturnVal(None)

class SetFunAndReturnNone:
    def __init__(self, f):
        self.f = f

class SetNoneAndReturnVal:
    def __init__(self, x):
        self.x = x

pc = None # program counter

def eval(f):
    global pc
    old_pc = pc # 如果eval里面用eval怎么办?保存老pc就OK
    pc = f
    result = None
    while pc:
        command = pc()
        if isinstance(command, SetFunAndReturnNone):
            pc = command.f
            result = None
        else:
            assert isinstance(command, SetNoneAndReturnVal)
            pc = None
            result = command.x
    pc = old_pc
    return result

def is_even_0(x):
    if x > 1:
        return SetFunAndReturnNone(lambda: is_even_1(x - 2))
    if x == 1:
        return SetNoneAndReturnVal(False)
    else:
        return SetNoneAndReturnVal(True)
    return SetNoneAndReturnVal(None)

def is_even_1(x):
    return SetFunAndReturnNone(lambda: is_even_0(x))
    return SetNoneAndReturnVal(None)

assert not eval(lambda: is_even_0(2345)) # 过!

我们现在架构的改动已经做完了。但是,还能优化下这个文件。

最基本的,pc只有eval在用了,能搬进去,也不用担心重入的问题了,因为每个eval的调用都会有自己的pc

def eval(f):
    pc = f # program counter
    result = None
    while pc:
        command = pc()
        if isinstance(command, SetFunAndReturnNone):
            pc = command.f
            result = None
        else:
            assert isinstance(command, SetNoneAndReturnVal)
            pc = None
            result = command.x
    return result

pc = None的时候能直接返回,不需要在while里面:

def eval(f):
    pc = f # program counter
    result = None
    while pc:
        command = pc()
        if isinstance(command, SetFunAndReturnNone):
            pc = command.f
            result = None
        else:
            assert isinstance(command, SetNoneAndReturnVal)
            return command.x
    return result

result可以去掉了,因为只会assign成None

def eval(f):
    pc = f # program counter
    while pc:
        command = pc()
        if isinstance(command, SetFunAndReturnNone):
            pc = command.f
        else:
            assert isinstance(command, SetNoneAndReturnVal)
            return command.x

如果我们假设SetFunAndReturnNone.f不会是None,f也不是None,pc在循环中不会变成None。我们就能直接while True:

我们也假设f不会是None - 我们的所有代码都不会传个None进去

def eval(f):
    pc = f # program counter
    while True:
        command = pc()
        if isinstance(command, SetFunAndReturnNone):
            pc = command.f
        else:
            assert isinstance(command, SetNoneAndReturnVal)
            return command.x

这两个class名字好长。。。而且pc已经是内部实现了,我们不应该暴露出来。

我们重命名SetFunAndReturnNone做More(需要更多计算),同样的,重命名SetNoneAndReturnVal为Done(搞定)。

我们最后,再优化一下使用代码,去掉不需要的return。

class More:
    def __init__(self, f):
        self.f = f

class Done:
    def __init__(self, x):
        self.x = x

def eval(f):
    pc = f # program counter
    while True:
        command = pc()
        if isinstance(command, More):
            pc = command.f
        else:
            assert isinstance(command, Done)
            return command.x

def is_even_0(x):
    if x > 1:
        return More(lambda: is_even_1(x - 2))
    if x == 1:
        return Done(False)
    else:
        return Done(True)

def is_even_1(x):
    return More(lambda: is_even_0(x))

assert not eval(lambda: is_even_0(2345)) # pass!

大功告成。

我们最后,跟小杰站着世界树之巅,找爸爸之余看看我们实现了什么:

当我们eval的时候,我们会先调用pc。

然后,我们会push一个stack frame。

当tail call出现的时候,与其去apply之(并且把这个stack frame也push进去),我们直接返回!

这样,我们当前的stack frame就会pop掉。

pop掉后,我们又立刻进入该 call,再push一个frame。

我们的frame数量,就会这样0-1-0-1-0-1,循环往复。故此,这个方法叫做trampoline。

如果你觉得trampoline很有趣,我推荐去看http://blog.higher-order.com/assets/trampolines.pdf,简单又漂亮的一篇paper。