#!/usr/bin/python3
"""Find optimal ways to multiply by small constant integers.

The question is, suppose you can add and subtract quantities, and you
want to compute a particular multiple of some base quantity x. And
you can’t bitshift: maybe “addition” is really multiplication or
convolution or function composition or something, not binary
addition. What is the optimal way to multiply by some small constant?
(This is based on a problem in Stepanov. In fact, it might be
identical.)

One solution is to compute the multiples of powers of 2 as high as you
need, then add them together. For example, to compute 15x, you could
start by computing 2x, 4x, and 8x, then 8x+4x+2x+x. This requires six
additions. But it turns out that you can do it in five additions
instead by computing 2x, 3x, 5x, 10x, and 15x. This program finds
such optimal ways to compute multiples.

Subtraction is helpful for computing 31x by computing 2x, 4x, 8x, 16x,
32x, and then 31x as 32x - x. Subtraction is not helpful for smaller
positive multiples.

Some of the sequences it finds include negative multiples, even for
positive numbers. For example, you can compute 63x in 7 steps as 2x,
4x, 8x, 16x, 32x, -31x, and then 63x as 32x - -31x. But this sort of
thing is totally unnecessary: 2x, 4x, 8x, 16x, 32x, 64x, and then 63x
as 64x - x takes the same 7 steps.

This program is efficient enough to be practical for finding optimal
operation sequences of up to at least 20 steps with no bitshift.

There *is* a `-s` option to allow left bitshifts of up to some maximum
size, treating them as zero-cost, which is true for cases like the
i386 LEA instruction, the ARM operand bitshifts, and wiring up
registers in logic. The simplest example of this is the difference of
two powers of 2. To multiply by 60 without shifts, `optmul.py -o 60 8`
finds this sequence:

    v1 = x             # x
    v2 = v1 + v1       # 2x
    v3 = v2 + v1       # 3x
    v4 = v3 + v2       # 5x
    v5 = v4 + v4       # 10x
    v6 = v5 + v4       # 15x
    v7 = v6 + v6       # 30x
    v8 = v7 + v7       # 60x

By contrast, `optmul.py -s 6 -o 60 8` finds this sequence:

    v1 = x
    v2 = (v1 << 6) - (v1 << 2)

As another example, to multiply by 85 with a maximum bitshift of at
least 4, this program finds this:

    v1 = x
    v2 = v1 + (v1 << 2)
    v3 = v2 + (v2 << 4)

`optmul.py -s 4 2` finds 426 multipliers that are thus achievable in
only two additions or subtractions with bitshifts of no more than 4,
including all the integers in [-35, 41], and 90 of the 100 numbers in
[0, 99]. (Except that it doesn’t show zero.)

"""
import argparse


def describe_op(dest, left, lsh, op, right, rsh):
    # Render one step, omitting the shift notation for shifts of zero bits.
    return 'v%d = %s %s %s' % (dest,
                               ('(v%d << %d)' % (left, lsh) if lsh
                                else 'v%d' % left),
                               op,
                               ('(v%d << %d)' % (right, rsh) if rsh
                                else 'v%d' % right))


def _sequences(n_steps, shift):
    # Each sequence is a list of (multiple, description) pairs.
    if not n_steps:
        yield [(0, 'v0 = 0'), (1, 'v1 = x')]
        return

    # Only the first sequence found for each result at this depth is kept.
    already_found = set()
    for seq in sequences(n_steps-1, shift):
        # Addition is commutative, so right only ranges up to left.
        for left, lsh, right, rsh in ((left, lsh, right, rsh)
                                      for left in range(len(seq))
                                      for lsh in range(shift+1)
                                      for right in range(left+1)
                                      for rsh in range(shift+1)):
            result = (seq[left][0] << lsh) + (seq[right][0] << rsh)
            if (result not in already_found
                and result not in [pr for pr, *_ in seq]):
                desc = describe_op(len(seq), left, lsh, '+', right, rsh)
                yield seq + [(result, desc)]
                already_found.add(result)

        # Subtraction is not commutative, so try every ordered pair.
        for left, lsh, right, rsh in ((left, lsh, right, rsh)
                                      for left in range(len(seq))
                                      for lsh in range(shift+1)
                                      for right in range(len(seq))
                                      for rsh in range(shift+1)):
            result = (seq[left][0] << lsh) - (seq[right][0] << rsh)
            if (result not in already_found
                and result not in [pr for pr, *_ in seq]):
                desc = describe_op(len(seq), left, lsh, '-', right, rsh)
                yield seq + [(result, desc)]
                already_found.add(result)


# Memoizes the sequence list for each (n_steps, shift) pair.
seq_cache = {}

def sequences(n_steps, shift):
    if (n_steps, shift) not in seq_cache:
        seq_cache[n_steps, shift] = list(_sequences(n_steps, shift))

    return seq_cache[n_steps, shift]


def optimal(n, shift=0):
    # Map each reachable multiple to the first (shortest) sequence found.
    best = {}
    for n_steps in range(n+1):
        for sequence in sequences(n_steps, shift):
            result = sequence[-1][0]
            if result not in best:
                best[result] = sequence

    return best


def main():
    formatter = argparse.RawDescriptionHelpFormatter
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=formatter)
    p.add_argument('n', type=int,
                   help='the number of steps up to which to compute')
    p.add_argument('-o', '--only', type=int,
                   help='only search for the optimal way to compute'
                   ' a particular multiple')
    p.add_argument('-s', '--shift', type=int, default=0,
                   help='allow bitwise left shifts of up to SHIFT bits')
    args = p.parse_args()

    if args.only is not None:
        # Deepen one step at a time and stop at the first hit.
        for n in range(args.n+1):
            best = optimal(n, shift=args.shift)
            if args.only in best:
                print(best[args.only])
                break
    else:
        best = optimal(args.n, shift=args.shift)
        for i, method in sorted(best.items()):
            print(i, len(method), method)


if __name__ == '__main__':
    main()