我们先从一个很简单的算法,快速幂,开始看起吧:

let rec quick_mod x y m =
  if y = 0 then 1
  else if y mod 2 = 1 then (x * (quick_mod x (y - 1) m) mod m)
  else quick_mod ((x * x) mod m) (y / 2) m

let example_0 = quick_mod 10 1000 3

显而易见,10的n次方除3都会余1:

utop # example_0;;
- : int = 1

这个算法试图在y很大的时候,算(x^y) mod m。

很明显,这个算法只会进行log y次’*然后mod’的操作 - 我们把这个操作定义为⊕吧。

我们去深入思考一下 - 为什么本来要进行y次⊕,现在,我们只需要进行log y次⊕?

这是因为我们通过重排,复用了计算中相同的部分。

比如如果我们想进行4次⊕,一般来说,我们会写下

x ⊕ (x ⊕ (x ⊕ x))

但是,在上面的算法中,我们计算的是

let x2 = x  x in
x2  x2

这时候,通过对x2的复用,我们省掉了一个⊕。

但是,如果我们把这个表达式展开,我们会发现 - 我们现在计算的是(x ⊕ x) ⊕ (x ⊕ x)。

严格意义上说,我们在计算着另一个很接近的东东 - 我们把⊕的顺序重排了。

这是因为 - 只有通过重排后,我们才有复用空间。

不过恰巧,我们有a ⊕ (b ⊕ c) = (a ⊕ b) ⊕ c,所以快速幂是正确的。

但是注意一下 - 无论⊕是什么,只需要我们满足上述等式,我们就有这个上面的’快速幂'!

我们把这个更抽象的新算法实现实现下吧:

我们要先定义⊕,就把‘带⊕的东东’叫做monoid吧:

type 'a monoid = {
  zero : 'a;
  add : 'a -> 'a -> 'a;
}

要注意,我们要处理y = 0的情况,所以也需要一个zero元素。

我们额外需要满足zero ⊕ x = x = x ⊕ zero。

(*exponentiation by squaring*)
(*multiplication by doubling*)
let rec m_b_d (m : 'a monoid) (x : 'a) (y : int) =
  if y = 0 then m.zero
  else if y mod 2 = 1 then m.add x (m_b_d m x (y - 1))
  else m_b_d m (m.add x x) (y / 2)

现在我们的抽象算法长这个样子。这个算法的学名叫做exponentiation by squaring,但是因为我们认为我们是在对+进行抽象,就叫multiplication by doubling吧。

let int_add_monoid : int monoid = {
  zero = 0;
  add = ( + );
}

let int_mult_monoid : int monoid = {
  zero = 1;
  add = ( * );
}

let modular_mult_monoid (m : int) : int monoid = {
  zero = 1;
  add = fun x y -> (x * y) mod m
}

int上的+跟*都构成了一个monoid,而且原本的 ⊕ 也是monoid。

小心留意一下,*的zero是1而不是0,因为1 * x = x。

let example_1 = m_b_d int_add_monoid 5 7
let example_2 = m_b_d int_mult_monoid 5 7
let example_3 = m_b_d (modular_mult_monoid 3) 10 1000

我们可以随便找一些例子,验证一下计算结果。

utop # example_1;;
- : int = 35
utop # example_2;;
- : int = 78125
utop # example_3;;
- : int = 1

但是,这些例子,或多或少都有一些太无聊了。我们抽象以后,有意思的实例依然只有快速幂。

有没有一些有意思的场景,可以使用m_b_d的?

有很多!有一个著名的blog post,<A Very General Method of Computing Shortest Paths>,就描述了上述算法可以来找shortest path,transitive closure,解线性方程,跟找DFA的正则表达式等等工作。

但是,我们不该离题太远,就描述一个简单又通用的问题吧 - 给一个正方形矩阵X,计算X^n。

当我们写出这个以后,我们可以用之来解counting coin, fibonacci等linearly recurrence的函数,也可以用来快速的saturate一个random walk。事实上,google正是把整个互联网看成一个巨大的graph,在上面利用random walk来实现pagerank的。如何给出transition probability,计算出最后每个node的停留概率?正是计算X^n which is multiplication by doubling!

那我们就来实现矩阵乘法吧。要注意 - 矩阵乘法,既有+又有*,而这两都刚好是monoid。

我们跟前面抽象出monoid一样,把+ *抽象一下:

type 'a semiring = {
  zero : 'a;
  one : 'a;
  add : 'a -> 'a -> 'a;
  mult : 'a -> 'a -> 'a;
}

let semiring_from_monoid (add : 'a monoid) (mult : 'a monoid) : 'a semiring = {
  zero = add.zero;
  one = mult.zero;
  add = add.add;
  mult = mult.add;
}

let add_monoid_from_semiring (r : 'a semiring) : 'a monoid = {
  zero = r.zero;
  add = r.add;
}

let mult_monoid_from_semiring (r : 'a semiring) : 'a monoid = {
  zero = r.one;
  add = r.mult;
}

let int_semiring : int semiring = semiring_from_monoid int_add_monoid int_mult_monoid

从两个monoid我们可以得到semiring(如果符合结合律),也可以从semiring得到monoid,int也是个semiring。

现在,我们可以定义矩阵。矩阵是一个数组的数组。

type 'a matrix = {
  data : 'a array array;
  m : int;
  n : int;
}

我们可以拿到矩阵的行向量跟列向量,也可以实现点乘。

let row (mat : 'a matrix) (m : int) : 'a array = mat.data.(m)

let col (mat : 'a matrix) (n : int) : 'a array =
  Array.init (mat.m) (fun m -> mat.data.(m).(n))

let dot (r : 'a semiring) (x : 'a array) (y : 'a array) : 'a =
  Array.fold_left r.add r.zero (Array.map2 r.mult x y)

有了这些,我们就可以实现矩阵乘法了。

let build_matrix (f : int -> int -> 'a) (m : int) (n : int) : 'a matrix = {
  data = Array.init m (fun m -> Array.init n (fun n -> f m n));
  m = m;
  n = n;
}

let mm (r : 'a semiring) (x : 'a matrix) (y : 'a matrix) : 'a matrix = 
  build_matrix (fun m n -> dot r (row x m) (col y n)) x.m y.n

mm也是一个monoid,于是我们可以利用m_b_d求矩阵指数。

let mm_semiring (r : 'a semiring) (mn : int) : 'a matrix monoid = {
  zero = build_matrix (fun m n -> if m = n then r.one else r.zero) mn mn;
  add = mm r;
}

let example_4 : int matrix =
  let m = mm_semiring int_semiring 3 in
  m_b_d m m.zero 5

let example_5 : int matrix =
  let m = mm_semiring int_semiring 3 in
  let mat = build_matrix (fun _ _ -> 1) 3 3 in
  m_b_d m mat 5

单位矩阵自乘应该是单位矩阵,下面的则是在计算3的次方:

utop # example_4;;
- : int matrix =
{data = [|[|1; 0; 0|]; [|0; 1; 0|]; [|0; 0; 1|]|]; m = 3; n = 3}
utop # example_5;;
- : int matrix =
{data = [|[|81; 81; 81|]; [|81; 81; 81|]; [|81; 81; 81|]|]; m = 3; n = 3}

但是,我们这个实现在性能上有很多可以改进的地方:

0 - 用array of array来代表matrix的时候,相邻array不一定在内存上相邻,不能很好利用cache locality

1 - OCaml array会做bound checking,于是我们每个访问都会带来不小开销

2 - col会返回一个新array,带来allocation开销跟进一步降低cache locality

3 - 我们用到了大量的高阶函数,如果inlining失效,这些高阶函数将会在runtime存在,并且带来大量pipeline stalling

某种意义上,这是冯洛伊曼瓶颈的极致体现:我们只想做一些浮点的+跟*运算,但是为了做这些运算,我们需要far memory fetch,需要bound checking,需要allocator,需要closure!

所以,我们写出了一个很漂亮很通用的算法,但是代价是做出了很多性能上的妥协。我们如果去手写C甚至CUDA,这些问题可以避免,但是就跟上述的漂亮或者通用无关了。

世间安得双全法,不负如来不负卿。

但是如果我告诉你,双全法存在呢?

怎么搞?

我们可以向星际争霸选手学习。不,不是学习他们的眼瞎,而是学习他们对timing的把握。

我们要注意到,计算机中,不同数据的‘时效’是不一样的。

比如在pagerank里面,互联网千变万化,所以这个转移概率是每一瞬间都在更新。

但是,需要自乘多少次,大概是一个hard code的大number

至于矩阵大小呢?诚然,互联网上每时每刻都在增加新页面,但是我们可以首先预留一些空位,这样,矩阵大小也不需要经常变来变去。

所以,我们更早知道矩阵大小n,跟自乘次数y。

但是,这两个数,控制着所有的内存读写,跟所有的控制流!只要知道这两个数,我们就能定死控制流跟数据流,矩阵的值只会控制具体计算的数值。

那,我们只需要在知道这两个数的最早时刻,把控制跟数据存取做了,剩下的只有很多很多计算。然后,我们只需要每次重复做这些不含控制的运算就好了。

某种意义上很不可思议,是吧 - 如果告诉你矩阵前两位,同样是两个数,你什么都做不了。

那,我们开始吧。我们需要设计一个编程语言,这个编程语言只有运算而没控制,来表达‘优化后剩下的’:

type expr =
| Var of string
| Int of int
| Add of expr * expr
| Mult of expr * expr

我们只有常量,变量,跟+ *。

let expr_semiring : expr semiring = {
  zero = Int 0;
  one = Int 1;
  add = (fun x y -> Add (x, y));
  mult = (fun x y -> Mult (x, y));
}

let example_6 : expr = m_b_d (add_monoid_from_semiring expr_semiring) (Var "x") 13

这很trivially是个semiring。我们现在可以先从简单的例子开始,用+表示x * 13:

utop # example_6;;
- : expr/2 =
Add (Var "x",
 Add (Add (Add (Var "x", Var "x"), Add (Var "x", Var "x")),
  Add
   (Add (Add (Add (Var "x", Var "x"), Add (Var "x", Var "x")),
     Add (Add (Var "x", Var "x"), Add (Var "x", Var "x"))),
   Int 0)))

奇怪,怎么x还是出现了13次,重排是重排了,但是大O错了?

是这样的:在调用m_b_d的时候,m_b_d的确只用了log次+,但是,这颗AST的节点有很多的sharing,当我们把sharing展开的时候,就会指数爆炸。

从另一个角度看,我们忘了给语言加入‘绑定变量’,那从根本上就不可能用m_b_d的trick优化计算。

那,我们把let加进expr吧

type expr =
| Var of string
| Int of int
| Add of expr * expr
| Mult of expr * expr
| Let of string * expr * expr
| Array of expr array

无视一下Array,以后要用上。

为了避免刚刚的问题再次发生,我们可以限定:除了Var跟Int这两计算起来简单,AST上也由一个trivial的single node表达的expr,其他的一切expr,都是二等公民,不能传来传去(否则会不小心复制)!

看上去,这个限制很大,但是我们可以通过Let,把一切都绑定到变量上,然后变量就可以到处飞了。

我们可以用一个可变列表of binding name to bounded value,来记录下所有的绑定。

let counter : int ref = ref 0

let fresh () : string =
  let x = !counter in
  counter := x + 1;
  "x" ^ string_of_int x

type letlist = (string * expr) list ref

let new_letlist () : letlist = ref []

当我们有一个新的复杂表达式,我们可以放上这个列表,然后拿回一个变量:

let push_letlist (ll : letlist) (x : expr) : expr =
  let x_ = fresh() in
  ll := (x_, x) :: !ll;
  Var x_

我们用完letlist后,可以用Let变回一个AST:

let let_letlist (ll : letlist) (x : expr) : expr =
  List.fold_left (fun x (n, v) -> Let (n, v, x)) x !ll

我们也可以用著名的bracket trick来保证let_letlist不会被忘记使用,来防止变量变成unbounded variable:

let with_letlist (f : letlist -> expr) : expr =
  let ll = new_letlist () in
  let_letlist ll (f ll)

有了这些,我们就可以重新实现ring:

let expr_semiring (ll : letlist) : expr semiring = {
  zero = Int 0;
  one = Int 1; 
  add = (fun x y -> push_letlist ll (Add (x, y)));
  mult = (fun x y -> push_letlist ll (Mult (x, y)));
}

现在再试一下刚刚的问题:

let example_7 : expr =
  with_letlist (fun ll ->
    m_b_d (add_monoid_from_semiring (expr_semiring ll)) (Var "x") 13)

的确把问题修复了!

utop # example_7;;
- : expr =
Let ("x0", Add (Var "x", Var "x"),
 Let ("x1", Add (Var "x0", Var "x0"),
  Let ("x2", Add (Var "x1", Var "x1"),
   Let ("x3", Add (Var "x2", Int 0),
    Let ("x4", Add (Var "x1", Var "x3"),
     Let ("x5", Add (Var "x", Var "x4"), Var "x5"))))))

试一试矩阵乘法:

let example_8 : expr =
  with_letlist (fun ll ->
    let m = build_matrix (fun m n -> 
      Var ("v" ^ string_of_int m ^ string_of_int n)) 2 2 in
    let em = m_b_d (mm_semiring (expr_semiring ll) 2) m 2 in
    Array (Array.map (fun a -> Array a) em.data))

的确把所有的控制流跟数据存取都消除了!

utop # example_8;;
- : expr =
Let ("x6", Mult (Var "v00", Var "v00"),
 Let ("x7", Mult (Var "v01", Var "v10"),
  Let ("x8", Add (Int 0, Var "x6"),
   Let ("x9", Add (Var "x8", Var "x7"),
    Let ("x10", Mult (Var "v00", Var "v01"),
     Let ("x11", Mult (Var "v01", Var "v11"),
      Let ("x12", Add (Int 0, Var "x10"),
       Let ("x13", Add (Var "x12", Var "x11"),
        Let ("x14", Mult (Var "v10", Var "v00"),
         Let ("x15", Mult (Var "v11", Var "v10"),
          Let ("x16", Add (Int 0, Var "x14"),
           Let ("x17", Add (Var "x16", Var "x15"),
            Let ("x18", Mult (Var "v10", Var "v01"),
             Let ("x19", Mult (Var "v11", Var "v11"),
              Let ("x20", Add (Int 0, Var "x18"),
               Let ("x21", Add (Var "x20", Var "x19"),
                Let ("x22", Mult (Var "x9", Int 1),
                 Let ("x23", Mult (Var "x13", Int 0),
                  Let ("x24", Add (Int 0, Var "x22"),
                   Let ("x25", Add (Var "x24", Var "x23"),
                    Let ("x26", Mult (Var "x9", Int 0),
                     Let ("x27", Mult (Var "x13", Int 1),
                      Let ("x28", Add (Int 0, Var "x26"),
                       Let ("x29", Add (Var "x28", Var "x27"),
                        Let ("x30", Mult (Var "x17", Int 1),
                         Let ("x31", Mult (Var "x21", Int 0),
                          Let ("x32", Add (Int 0, Var "x30"),
                           Let ("x33", Add (Var "x32", Var "x31"),
                            Let ("x34", Mult (Var "x17", Int 0),
                             Let ("x35", Mult (Var "x21", Int 1),
                              Let ("x36", Add (Int 0, Var "x34"),
                               Let ("x37", Add (Var "x36", Var "x35"),
                                Array
                                 [|Array [|Var "x25"; Var "x29"|];
                                   Array [|Var "x33"; Var "x37"|]|]))))))))))))))))))))))))))))))))

至此,我们完成了一开始提到的优化目标:消灭所有pointer fetch跟消灭所有控制流。

在达成这个目标的同时,我们并没对算法的通用性进行任何妥协:事实上,我们的矩阵乘法的代码,跟m_b_d的代码,完全没有动!这是因为我们的抽象,某种意义上,是stage-oblivious的:我们不关心矩阵里面存的是表达式还是int,我们只关心上面有+跟*。

这时候,我们得到了一个解决性能问题的通用工具:staging。

0 - 我们实现一个stage-oblivious program。这个程序应该通过良好的抽象,而忘掉‘最基本的类型’。

1 - 我们分析程序各种数据的时间差。并且发现一个可以利用的时间区分,把数据分成已知跟未知。

2 - 对于已知的程序,我们用具体的值表示,而对于未知的程序,我们用AST表示。