/*
From A Discipline of Programming, p.65, an interesting binary-search
square-root algorithm; n is the number whose root we want.
p, q, r := 0, 1, n; while (q <= n) q := q * 4;
while (q != 1) { q := q/4; h := p + q; p := p / 2;
if (r >= h) p, r := p + q, r - h; }
Or laid out like a sane person instead of a madman:
low = 0;
shift = 1;
remaining = n;
while (shift <= remaining) shift *= 4;
while (shift != 1) {
shift /= 4;
adjustment = low + shift;
low /= 2;
if (remaining >= adjustment) {
low += shift;
remaining -= adjustment;
}
}
Interesting thing about this is, it calculates a square root
efficiently (O(log N) time) without multiplication or division.
The idea is that you have a pair of variables a and c such that the
square root is in [a, a + c), and you do a binary search on that
interval by cutting c in half and then possibly incrementing a by it.
p = a * c = low
q = c*c = shift
r = n - a*a = remaining
h = 2 * p + q = adjustment
So, first you start with a= 0 and c=1; then you increase c until it's
big enough to encompass the range; then you keep halving the size of
the range by dividing c by 2, which results in q being divided by 4
and p by 2; then you have to figure out whether you want the lower
half of the range or the upper half of it, which is what the final if
statement is about. At the time we encounter the if:
h = 2*p + q = 2*a*c + c*c.
So
r - h = n - a*a - 2*a*c - c*c = n - (a + c)^2
So iff r >= h, then n >= (a+c)^2, which is to say, we want the
upper half of the range; it was not enough to decrease c, we must also
increase a by an increment of c. Now, a'+c' = a + c, so that guarantees we incremented it enough.
p + q = (a + c)*c
which is the right number for the new p, and I already solved for r - h.
A base-10 version is also included below.
Newton's Method is about the same amount of code (11 lines of code
conventionally formatted, slightly less than the above; 99 characters
golfed, or 79 if you use a crappy initial estimate, compared to 90 for
this approach), and on a machine with a good divider, about 60%
faster; but the termination and correctness arguments are much more
difficult. (And it may be silly to talk about machines with good
dividers, since almost all of them also have square root hardware.)
*/
#include <assert.h>
#include <stdio.h>
/* 2**30 sqrts in 5m6s on a 1.6GHz Atom: 3.5M/sec, about 456 cycles each */
/*
 * Integer square root by binary search, using only shifts-by-constant,
 * additions and subtractions.
 *
 * Returns the largest value whose square is <= n; returns 0 when n <= 0
 * (the first loop never runs, so the second one is skipped too).
 *
 * In terms of a search interval [a, a + c) that contains the root:
 *   p = a * c,  q = c * c,  r = n - a * a.
 * All working values are long long because q has to grow strictly past n,
 * which needs 2^32 when n >= 2^30 and would overflow a 32-bit int.
 */
static inline int
sqt(int n)
{
    long long p = 0;
    long long q = 1;
    long long r = n;
    long long h;

    /* Grow q (= c*c) to the smallest power of 4 that exceeds n. */
    while (q <= r) {
        q *= 4;
    }

    /* Halve c each round (q /= 4, p /= 2) until c == 1. */
    while (q != 1) {
        q /= 4;
        h = p + q;          /* old p is 2*a*c for the new c, so h = 2*a*c + c*c */
        p /= 2;
        if (r >= h) {       /* n >= (a + c)^2: the root is in the upper half */
            p += q;         /* a becomes a + c, so p becomes (a + c) * c */
            r -= h;         /* r becomes n - (a + c)^2 */
        }
    }

    /* c == 1 here, so p == a, the integer square root. */
    return (int)p;
}
/*
 * Base-10 digit-by-digit integer square root, instrumented with assertions
 * that check the algorithm's invariants at every step.
 *
 * Returns the largest value whose square is <= n.  n must be non-negative:
 * for negative n the final assert(a*a <= n) fails.
 *
 * The working variables are long long: shift must grow strictly past n,
 * which takes 10^10 once n >= 10^8, and the checked quantity (a+c)*(a+c)
 * can exceed INT_MAX for n close to INT_MAX.  Both overflow a 32-bit int.
 */
int sqrt_base10(int n) {
    long long low = 0;          /* a * c */
    long long shift = 1;        /* c * c */
    long long remaining = n;    /* n - a * a */
    long long adjustment;       /* 2*a*c + c*c = (a + c)^2 - a^2 */

    /* These "underlying" variables are only used for assertions. */
    long long a = 0;
    long long c = 1;

    /* Widen the interval a decimal digit at a time until c*c exceeds n. */
    while (shift <= remaining) {
        shift *= 100;
        c *= 10;
    }

    /* loop invariant: square root is in [a, a + c), where shift = c*c,
     * low = a*c, and remaining = n - a*a.
     */
    while (shift != 1) {
        assert(shift == c*c);
        assert(low == a*c);
        assert(remaining == n - a*a);
        assert(a*a <= n);
        assert((a+c)*(a+c) > n);

        /* Shrink c by a factor of ten; shift and low follow from their
         * definitions in terms of c. */
        shift /= 100;
        low /= 10;
        c /= 10;
        adjustment = 2 * low + shift; /* 2*a*c + c*c */

        assert(shift == c*c);
        assert(low == a*c);
        assert(remaining == n - a*a);

        /* At this point we may have to increase a (by c) to get the
         * square root back into the range.  That involves recalculating
         * the things that depend on a: low (needs to become (a+c)c,
         * i.e. low+shift), remaining (needs to become n - a*a - 2*a*c -
         * c*c, i.e. n - (a+c)**2), and adjustment. */
        while (remaining >= adjustment) {
            a += c;
            low += shift;
            remaining -= adjustment;
            adjustment = 2 * low + shift;
            assert(shift == c*c);
            assert(low == a*c);
            assert(remaining == n - a*a);
        }

        assert(a*a <= n);
        assert((a+c)*(a+c) > n);
    }

    assert(c == 1);
    assert(a*a <= n);
    assert((a+c)*(a+c) > n);

    /* c == 1, so low == a*c == a. */
    return (int)low;
}
/*
 * Base-10 digit-by-digit integer square root, without assertions.
 *
 * Returns the largest value whose square is <= n; returns 0 when n <= 0
 * (neither loop body runs).
 *
 * With the root known to lie in [a, a + c):
 *   low = a * c,  shift = c * c,  remaining = n - a * a.
 * The working variables are long long because shift must grow strictly
 * past n, which takes 10^10 once n >= 10^8 and overflows a 32-bit int.
 */
int sqrt_base10_bare(int n) {
    long long low = 0;
    long long shift = 1;
    long long remaining = n;
    long long adjustment;

    /* Grow shift to the smallest power of 100 that exceeds n. */
    while (shift <= remaining) {
        shift *= 100;
    }

    while (shift != 1) {
        /* Move to the next decimal digit: c /= 10. */
        shift /= 100;
        low /= 10;
        adjustment = 2 * low + shift; /* 2*a*c + c*c */

        /* Bump a by c as often as (a + c)^2 still fits under n;
         * each bump adds one to the current decimal digit. */
        while (remaining >= adjustment) {
            low += shift;
            remaining -= adjustment;
            adjustment = 2 * low + shift;
        }
    }

    /* c == 1 here, so low == a. */
    return (int)low;
}
/* Holds the reference result for the current input.
 * NOTE(review): presumably volatile so the compiler cannot discard the
 * sqt() call when benchmarking - confirm. */
volatile int jj;

/*
 * Self-test driver: for every ii in [0, 42949672) computes sqt(ii) and
 * checks it against the other implementations and against the definition
 * of an integer square root (jj*jj <= ii < (jj+1)*(jj+1)).
 * Mismatches are reported on stdout; always returns 0.
 */
int main(void) {
    int ii, kk;

    for (ii = 0; ii < 42949672; ii++) {
        jj = sqt(ii);

        /* NOTE(review): t, s, p and a are not defined anywhere in the
         * visible file; presumably they are further square-root variants
         * and the variables they leave their results in - confirm and
         * restore their definitions before building. */
        t(ii);
        assert(jj == p);
        s(ii);
        assert(jj == a);

        if (jj*jj > ii) printf("too high: %d: %d > %d\n", jj, jj*jj, ii);

        kk = sqrt_base10(ii);
        if (kk != jj) printf("base10 version buggy: %d != %d at %d\n", kk, jj, ii);

        kk = sqrt_base10_bare(ii);
        if (kk != jj) printf("base10_bare version buggy: %d != %d at %d\n", kk, jj, ii);

        /* The next integer up must already overshoot ii. */
        jj++;
        if (jj*jj <= ii) printf("too low: %d: %d <= %d\n", jj, jj*jj, ii);

        /* Progress report at the end of each block of 2^24 inputs. */
        if (ii % 16777216 == 16777215) printf("checked up to %d\n", ii);
    }
    return 0;
}