#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Karatsuba multiplication exercise.

Suppose you want to calculate some product:

    (aX + b)(cX + d) = acX² + (ad + bc)X + bd

Here X is some power of the base of your number system, and this is the
conventional algorithm for multiple-precision multiplication. This divides
the problem of multiplying two numbers “ab” and “cd” into the problem of
multiplying four pairs of numbers, each half as long; so it’s a sort of
recursive divide-and-conquer algorithm which, in the end, takes O(N²)
time: for 2ⁱ digits, you do i levels of divide-and-conquer, producing 4ⁱ
bottom-level multiplications, which is just the square of the number of
digits. These multiplications are then combined in a smaller number of
shifted addition operations.

Karatsuba came up with a different way to do this, computing

    (a + b)(c + d) = ac + bc + ad + bd.

This contains the ad + bc sum we need as a couple of subterms. If we
compute ac and bd, we can subtract them off to get ad + bc. For example,
93 × 24: ac = 9×2 = 18; bd = 3×4 = 12; (a + b)(c + d) = (9+3)(2+4) =
12 × 6 = 72; 72 - 18 - 12 = 42. So our final result is
1800 + 420 + 12 = 2232, which is correct.

This has the advantage that, although the operations per internal node
are slightly more complicated, instead of 4ⁱ bottom-level multiplications
you have 3ⁱ. So, for example, if you have a 1,048,576-digit (2²⁰-digit)
number, you need 1,099,511,627,776 bottom-level multiplications with the
conventional algorithm, but only 3,486,784,401 with Karatsuba’s
algorithm, which is about 0.3% of the number needed by the conventional
algorithm.

So here’s a simple arbitrary-precision non-negative integer library
implemented in pure Python, just enough to enable me to write Karatsuba
multiplication. Numbers are represented as little-endian digit lists
(least significant digit first). To do the example computation above:

>>> mul([3, 9], [4, 2], 10)
[2, 3, 2, 2]

Normally on a computer you use a base much larger than 10.
"""
from __future__ import print_function

# Public API; keeps the `test` harness out of `from ... import *`.
__all__ = ["normalize", "add", "sub", "shift", "mul", "encode", "decode"]


def normalize(n):
    """Remove most significant zeroes from a digit list.

    Returns a new list; the argument is not modified. The value zero
    normalizes to the empty list.
    """
    n = list(n)
    while n and n[-1] == 0:
        n.pop()
    return n


def add(a, b, x):
    """Compute the sum of little-endian digit lists a and b in base x.

    So in, for example, add(a=[3, 9], b=[4, 2], x=10), the numbers being
    represented are 93 and 24, and the result [7, 1, 1] represents 117.
    But in add(a=[3, 9], b=[4, 2], x=12), the numbers being represented
    are 111 and 28, and the result [7, 11] represents 139.

    The result is normalized (no most significant zeroes).
    """
    result = []
    carry = 0
    for i in range(max(len(a), len(b))):
        ai = a[i] if i < len(a) else 0
        bi = b[i] if i < len(b) else 0
        carry, digit = divmod(ai + bi + carry, x)
        result.append(digit)
    result.append(carry)  # possible final carry; stripped if zero
    return normalize(result)


def sub(a, b, x):
    """Compute the difference a - b of little-endian digit lists in base x.

    The result is normalized (no most significant zeroes).

    Raises a ValueError (with args (a, b, x)) if the result is negative.
    """
    diff = []
    borrow = 0
    for i in range(max(len(a), len(b))):
        ai = a[i] if i < len(a) else 0
        bi = b[i] if i < len(b) else 0
        di = ai - bi - borrow
        if di >= 0:
            diff.append(di)
            borrow = 0
        else:
            diff.append(di + x)
            borrow = 1
    if borrow:
        # A borrow out of the top digit means b > a.
        raise ValueError(a, b, x)
    return normalize(diff)


def shift(a, n):
    """Multiply a digit list a by power n of its base.

    In little-endian form this just prepends n zero digits.
    """
    return [0] * n + list(a)


def mul(a, b, x):
    """Compute the product of little-endian digit lists a and b in base x.

    This implements Karatsuba multiplication: each level of recursion does
    three half-size multiplications instead of four. The result is
    normalized (zero is returned as []).
    """
    a, b = list(a), list(b)
    # Pad the shorter operand with most significant zeroes so both can be
    # split at the same position.
    width = max(len(a), len(b))
    a += [0] * (width - len(a))
    b += [0] * (width - len(b))

    if width == 0:  # 0 · 0 = 0
        return []
    if width == 1:
        carry, digit = divmod(a[0] * b[0], x)
        return normalize([digit, carry])

    # Write a = A·X + B and b = C·X + D, where X = x**split.
    split = width // 2
    B, A = a[:split], a[split:]
    D, C = b[:split], b[split:]
    ac = mul(A, C, x)
    bd = mul(B, D, x)
    # (A + B)(C + D) - AC - BD = AD + BC: the middle term costs one
    # multiplication instead of two.
    middle = sub(sub(mul(add(A, B, x), add(C, D, x), x), ac, x), bd, x)
    return add(add(bd, shift(middle, split), x), shift(ac, 2 * split), x)


def encode(n, x):
    """Encode the non-negative integer n as a little-endian digit list in base x.

    Zero encodes as [0] (a single zero digit).

    Raises ValueError if x < 2 or n < 0. For those inputs the digit loop
    below would never terminate (x == 1, or negative n) or would divide
    by zero (x == 0).
    """
    if x < 2:
        raise ValueError("base must be at least 2", x)
    if n < 0:
        raise ValueError("cannot encode a negative number", n)
    digits = []
    while True:
        n, d = divmod(n, x)
        digits.append(d)
        if not n:
            return digits


def decode(n, x):
    """Decode the little-endian digit list n in base x into an integer."""
    return sum(x**i * ni for i, ni in enumerate(n))


def test(trials=1000):
    """Generate random multiplication problems and check them against Python.

    Prints a FAIL line for each mismatch and returns the number of failures.
    """
    import random
    failures = 0
    for _ in range(trials):
        a, b = random.randrange(10000000), random.randrange(10000000)
        x = random.randrange(2, 61)
        ad, bd = encode(a, x), encode(b, x)
        prod = mul(ad, bd, x)
        diff = decode(prod, x) - a * b
        if diff:
            print("FAIL", ad, bd, diff)
            failures += 1
    return failures


if __name__ == '__main__':
    # Non-zero exit status if any random case failed.
    raise SystemExit(1 if test() else 0)