This is the second entry in an intermittent series of posts on how to optimize some basic mathematical functions in Haskell. The previous entry is here.
Consider the following problem: A person is ascending a staircase with \(n\) steps. The ascendant is sufficiently tall and the steps sufficiently shallow that with each stride, the ascendant can take one, two, or three steps. How many different ways are there for the ascendant to reach the top?
For reasons that will soon become obvious, let's call this value the tribonacci number of \(n\). One simple way I have seen this problem solved is:
tribonacci1 :: Integer -> Integer
tribonacci1 n
  | n == 0    = 1
  | n < 0     = 0
  | otherwise = tribonacci1 (n-1) + tribonacci1 (n-2) + tribonacci1 (n-3)
This code has some virtues: it is a very simple recursion, translated almost directly from the problem description. When \(n=0\), there is exactly one sequence of steps that leads to the top, the empty sequence; when \(n\) is negative, there is of course no sequence of positive steps that leads to the top; when \(n\) is positive, the result is just the sum of the ways you could reach the top after first taking a single, double, or triple step, respectively.
This code also reveals why the function is called tribonacci: it is a variation of the well-known Fibonacci function that adds the last three numbers of the series rather than the last two.
The code also has a glaring fault: because the evaluation for each \(n\) requires the evaluation of the function for three more \(n\)s, its running time is exponential in \(n\). This renders it much too slow for all but the smallest \(n\), even on the fastest computers. Indeed, on my workstation, and even after compilation with GHC 7.10.2, calculating the value for \(n=30\) takes over 4 seconds. For \(n=35\), it took 83 seconds. For any larger \(n\), I lacked the patience.
One common optimization would be to memoize the function. That way the function is evaluated only once for each \(n\), and the performance becomes linear in \(n\).
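A minimal sketch of that approach (the name tribonacciMemo and this particular formulation are mine, not from the original): a lazy Data.Array caches tribonacci 0 through \(n\), and each cell refers back to earlier cells in the same array, so each value is computed at most once.

```haskell
import Data.Array

-- Hypothetical memoized variant: a lazy array of thunks, each forced at
-- most once. Entry k sums whichever of its up-to-three predecessors exist.
tribonacciMemo :: Int -> Integer
tribonacciMemo n = table ! n
  where
    table :: Array Int Integer
    table = listArray (0, max 0 n) [trib k | k <- [0 .. max 0 n]]
    trib :: Int -> Integer
    trib 0 = 1
    trib k = sum [table ! (k - d) | d <- [1, 2, 3], k - d >= 0]
```

Thanks to laziness, the array cells force one another in dependency order, so the whole computation takes \(O(n)\) additions without any explicit evaluation-order bookkeeping.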
But let's try something slightly different, though functionally equivalent. Instead of recursing to tribonacci \(n-1\), tribonacci \(n-2\), and tribonacci \(n-3\), let's define an alternative function, call it \(a\), which takes the triple (tribonacci \(n-1\), tribonacci \(n-2\), tribonacci \(n-3\)) and returns the triple (tribonacci \(n\), tribonacci \(n-1\), tribonacci \(n-2\)). This function just needs to be applied iteratively, and the value of tribonacci \(n\) can easily be extracted from the \(n\)-th iteration.
tribonacci2 :: Integer -> Integer
tribonacci2 n = (\(tn,_,_) -> tn) $ iterate a (1,0,0) !! fromIntegral n
  where
    a :: (Integer, Integer, Integer) -> (Integer, Integer, Integer)
    a (tnm1, tnm2, tnm3) = (tnm1 + tnm2 + tnm3, tnm1, tnm2)
This is much better. It evaluates for \(n=35\) instantly, even without compilation. Compiled, it evaluates for \(n=10^5\) (a 26,465-digit number!) in a second. Apart from addition ceasing to be a constant-time operation for very large numbers, the performance of this code is linear. If you require all the tribonaccis up to \(n\), you will need at least \(O(n)\) operations, so in that sense tribonacci2 is asymptotically optimal.
But what if you only required the final result? Then we can do better. To see how, examine the function \(a\) again:
$$(t_{n}, t_{n-1}, t_{n-2}) = a (t_{n-1},t_{n-2},t_{n-3}) = (t_{n-1} + t_{n-2} + t_{n-3}, t_{n-1}, t_{n-2})$$Each of its outputs is just a linear combination of its inputs. Hence, if we represent each tuple of 3 sequential tribonaccis as a 3-dimensional vector, the application of \(a\) is just a multiplication by a constant matrix:
$$\hat{t}_{n} = \hat{a} \cdot \hat{t}_{n-1}, \quad \hat{t}_{n} = \begin{pmatrix}t_n \\ t_{n-1} \\ t_{n-2} \end{pmatrix}, \quad \hat{a} = \begin{pmatrix}1 & 1 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0\end{pmatrix}$$Thanks to the associativity of matrix multiplication, this means that:
$$\hat{t}_{n} = \hat{a}^{n} \cdot \hat{t}_0$$We can code this as follows:
import Data.Array
tribonacci3 :: Integer -> Integer
tribonacci3 n
  | n == 0    = 1
  | otherwise = matPow n a ! (0,0)
  where
    a :: Array (Int,Int) Integer
    a = listArray ((0,0),(2,2)) [1,1,1, 1,0,0, 0,1,0]
    matMul :: Array (Int,Int) Integer -> Array (Int,Int) Integer -> Array (Int,Int) Integer
    matMul x y = array ((fst bs, fst be), (snd bs, snd be))
                       [((i,j), sum [x!(i,k) * y!(k,j) | k <- range bi]) | i <- range bs, j <- range be]
      where
        bs = (fst (fst (bounds x)), fst (snd (bounds x)))  -- row range of x
        be = (snd (fst (bounds y)), snd (snd (bounds y)))  -- column range of y
        bi = (snd (fst (bounds x)), snd (snd (bounds x)))  -- column range of x
    matPow :: Integer -> Array (Int,Int) Integer -> Array (Int,Int) Integer
    matPow k m
      | k == 1    = m
      | even k    = sq
      | otherwise = matMul m sq
      where
        half = matPow (div k 2) m
        sq   = matMul half half
This looks like a lot more code, but matPow and matMul are just standard matrix exponentiation and multiplication functions, which you probably have lying around in a library anyway. The rest of the code is no longer than the earlier versions.
This code runs in \(O(\log n)\) because matPow uses binary exponentiation. For \(n=10^5\), where tribonacci2 needed a second, tribonacci3 needs less than a millisecond. For larger \(n\), where tribonacci2 starts to falter, tribonacci3 keeps turning out ever more gigantic figures. The 264,650-digit result for \(n=10^6\) is calculated in 78 milliseconds; the 2,646,495 digits for \(n=10^7\), in less than 1.5 seconds. The supra-logarithmic slow-down is, again, due to the fact that, for sufficiently large integers, integer multiplication ceases to be a constant-time operation.
Note that there is a pattern in the length of the tribonacci numbers: the number of decimal digits of tribonacci \(n\) seems to approach \(\approx 0.264649\,n\), the number of binary digits \(\approx 0.879146\,n\).
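This is no coincidence. The recurrence has characteristic polynomial \(x^3 - x^2 - x - 1\), and its single real root (sometimes called the tribonacci constant) dominates the growth of the sequence:

$$t_n \sim C \cdot r^n, \qquad r^3 = r^2 + r + 1, \qquad r \approx 1.839287,$$

so the number of digits in base \(b\) grows like \(n \log_b r\): about \(0.264649\,n\) decimal digits and \(0.879146\,n\) binary digits, matching the observed pattern.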
In our optimizations we went from exponential to linear to (near-)logarithmic performance. That is pretty good.
PS: A brief coda to this post, stating and deriving a closed functional form for tribonacci of \(n\) can be found here.