Friday, September 26, 2014

Fastest bit counting

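The best bit counting algorithm, as far as I know, is the one published by folks at Stanford University (on the "Bit Twiddling Hacks" page). It is always O(1): a fixed sequence of shifts, masks, additions and one multiplication, no matter how many bits are set.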

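#include <stdint.h>

/* Count the 1 bits in a 32-bit word. Using uint32_t instead of int
   avoids signed overflow (undefined behaviour) in the subtraction and
   in the final multiply, and makes >> a logical shift. */
int bitcount(uint32_t n)
{
    n = n - ((n >> 1) & 0x55555555);                /* 2-bit counts */
    n = (n & 0x33333333) + ((n >> 2) & 0x33333333); /* 4-bit counts */
    n = (n + (n >> 4)) & 0x0F0F0F0F;                /* 8-bit counts */
    return (int)((n * 0x01010101) >> 24);           /* top byte = total */
}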

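For example, this is what I got after compiling it to 32-bit x86 assembly. The listing comes from a version that takes a signed int, hence the arithmetic shifts (sarl); with an unsigned argument they would normally be logical shifts (shrl). The decimal immediates are the masks from the C code: 1431655765 = 0x55555555, 858993459 = 0x33333333, 252645135 = 0x0F0F0F0F and 16843009 = 0x01010101.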

    .cfi_startproc
    movl    4(%esp), %eax
    movl    %eax, %edx
    sarl    %edx
    andl    $1431655765, %edx
    subl    %edx, %eax
    movl    %eax, %edx
    sarl    $2, %eax
    andl    $858993459, %edx
    andl    $858993459, %eax
    addl    %edx, %eax
    movl    %eax, %edx
    sarl    $4, %edx
    addl    %edx, %eax
    andl    $252645135, %eax
    imull   $16843009, %eax, %eax
    sarl    $24, %eax
    ret
    .cfi_endproc


But how does the algorithm actually work?

OK, let's try a simple case first. Assume n is a 4-bit number whose bits, from most to least significant, are written as n = {a,b,c,d}, where each of a, b, c, d is either 0 or 1. The decimal value of n is 8*a + 4*b + 2*c + d, and the number of 1 bits is a + b + c + d.

For example, for n=0b1101: a=1, b=1, c=0, d=1 (in decimal, 8*1 + 4*1 + 2*0 + 1 = 13), and the number of 1 bits is a+b+c+d = 1 + 1 + 0 + 1 = 3. So far so good?
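To make the notation concrete, here is a small C sketch that pulls a, b, c and d out of a 4-bit number and recomputes both the value 8*a + 4*b + 2*c + d and the count a + b + c + d. The names bits4, split4, value4 and ones4 are hypothetical helpers made up for this illustration.

```c
/* Hypothetical helpers for illustration: a is the most significant bit, d the least. */
struct bits4 { unsigned a, b, c, d; };

struct bits4 split4(unsigned n)
{
    struct bits4 r;
    r.a = (n >> 3) & 1;
    r.b = (n >> 2) & 1;
    r.c = (n >> 1) & 1;
    r.d = n & 1;
    return r;
}

/* The value of n: 8*a + 4*b + 2*c + d */
unsigned value4(struct bits4 x) { return 8 * x.a + 4 * x.b + 2 * x.c + x.d; }

/* The number of 1 bits: a + b + c + d */
unsigned ones4(struct bits4 x) { return x.a + x.b + x.c + x.d; }
```

For n = 13, split4 gives a=1, b=1, c=0, d=1, value4 gives back 13 and ones4 gives 3.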

Now, we know that counting the 1 bits is as simple as a+b+c+d. But wait... n itself is not a+b+c+d, it is 8*a + 4*b + 2*c + d. OK, let's tackle this step by step.

If we shift n one bit to the right, the least significant bit (d) falls off and every other bit's weight is halved, so n>>1 = 4*a + 2*b + c, right? Now subtract this from the original n:
  
n          = 8*a + 4*b + 2*c + d
n>>1       =       4*a + 2*b + c
----------------------------------- -
n - (n>>1) =       4*a + 2*b + c + d
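A quick way to convince yourself of this step is to check it for all 16 possible 4-bit values. This is only a sketch; shift_sub4 and shift_sub4_holds are hypothetical names.

```c
/* n - (n >> 1) for a 4-bit n (assumed to be in 0..15) */
unsigned shift_sub4(unsigned n)
{
    return n - (n >> 1);
}

/* Returns 1 if n - (n >> 1) == 4*a + 2*b + c + d for every 4-bit n, else 0. */
int shift_sub4_holds(void)
{
    for (unsigned n = 0; n < 16; n++) {
        unsigned a = (n >> 3) & 1, b = (n >> 2) & 1, c = (n >> 1) & 1, d = n & 1;
        if (shift_sub4(n) != 4 * a + 2 * b + c + d)
            return 0;
    }
    return 1;
}
```

For n = 13 this gives 13 - 6 = 7 = 4*1 + 2*1 + 0 + 1: closer to a count, but a and b still carry weights 4 and 2.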

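That's a good direction! But a and b still carry weights 4 and 2, and we want plain counts. The trick is to treat n as two 2-bit fields, {a,b} and {c,d}, and turn each field into the number of 1s it contains: a field worth 2*x + y becomes x + y once we subtract x. Written as a 4-bit nibble, (n>>1) is {0,a,b,c}, i.e. 4*a + 2*b + c. We only want to subtract 4*a and c, so we have to mask out 2*b: masking with binary 0101 (0x5) leaves just 4*a + c: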

n                  = 8*a + 4*b + 2*c + d
(n>>1) & 0x5       = 4*a +   0 +   c
----------------------------------------- -
n - ((n>>1) & 0x5) = 4*a + 4*b +   c + d  = 4*(a + b) + (c + d)
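The same exhaustive check works for the masked version. The sketch below (pair_sums4 and pair_sums4_holds are hypothetical names) also reads the two 2-bit fields back out, to confirm that the upper field holds a + b and the lower one holds c + d.

```c
/* n - ((n >> 1) & 0x5) for a 4-bit n. The inner parentheses matter:
   in C, & binds more loosely than -. */
unsigned pair_sums4(unsigned n)
{
    return n - ((n >> 1) & 0x5);
}

/* Returns 1 if, for every 4-bit n, the result's upper 2 bits equal a + b
   and its lower 2 bits equal c + d; otherwise 0. */
int pair_sums4_holds(void)
{
    for (unsigned n = 0; n < 16; n++) {
        unsigned a = (n >> 3) & 1, b = (n >> 2) & 1, c = (n >> 1) & 1, d = n & 1;
        unsigned r = pair_sums4(n);
        if ((r >> 2) != a + b || (r & 0x3) != c + d)
            return 0;
    }
    return 1;
}
```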

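Now store this result back into n, so from now on n = 4*(a + b) + (c + d): the upper two bits hold a + b and the lower two bits hold c + d.
To get c + d only, we mask n with 0b11, i.e. n & 0x03.
Shifting n right once is not enough: that gives 2*(a + b) plus the high bit of c + d. Shifting it twice drops the c + d field completely and leaves a + b!
To make sure we keep only the lowest two bits (which matters once n is wider than 4 bits), we mask the result with 0x03 as well: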

n & 0x03         = c + d
(n >> 2) & 0x03  = a + b
------------------------ +

Nice! Now we have the expression a + b + c + d.
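In code, reading and adding the two fields could look like this. It is a sketch that assumes m already has the form 4*(a + b) + (c + d) from the previous step; add_fields4 is a hypothetical name.

```c
/* m = 4*(a + b) + (c + d): two 2-bit fields, each holding a count of 0..2 */
unsigned add_fields4(unsigned m)
{
    unsigned low  = m & 0x3;         /* c + d */
    unsigned high = (m >> 2) & 0x3;  /* a + b: shift by two, not one */
    return low + high;
}
```

With m = 9 (0b1001, from n = 13) this returns 1 + 2 = 3. Shifting by only one would mix the high bit of c + d into the result: for m = 2 (a = b = 0, c = d = 1), m >> 1 is already 1 even though a + b = 0.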

So, for 4-bit bit counting, we can use the following (the inner parentheses are needed in C, because & binds more loosely than - and +):

n = n - ((n >> 1) & 0x5)
bits = (n & 0x3) + ((n >> 2) & 0x3)

Check: as in the example above, n = 13 (0b1101), so a=1, b=1, c=0, d=1.

n = 0b1101 - ((0b1101 >> 1) & 0b0101) = 0b1101 - (0b0110 & 0b0101) = 0b1101 - 0b0100 = 13 - 4 = 9 = 0b1001
Then the next step:
bits = (0b1001 & 0b0011) + ((0b1001 >> 2) & 0b0011) = 0b0001 + (0b0010 & 0b0011) = 0b0001 + 0b0010
or bits = 1 + 2 = 3!!
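Putting the two lines together gives a complete 4-bit counter. In this minimal sketch the name bitcount4 is made up, and it assumes only the low 4 bits of its argument matter. It is checked against the direct sum a + b + c + d for all 16 inputs.

```c
/* 4-bit bit count built from the two expressions above. */
unsigned bitcount4(unsigned n)
{
    n &= 0xF;                             /* assumption: count only the low 4 bits */
    n = n - ((n >> 1) & 0x5);             /* now n = 4*(a + b) + (c + d) */
    return (n & 0x3) + ((n >> 2) & 0x3);  /* (c + d) + (a + b) */
}

/* Returns 1 if bitcount4 agrees with a + b + c + d for every 4-bit n, else 0. */
int bitcount4_holds(void)
{
    for (unsigned n = 0; n < 16; n++) {
        unsigned direct = ((n >> 3) & 1) + ((n >> 2) & 1) + ((n >> 1) & 1) + (n & 1);
        if (bitcount4(n) != direct)
            return 0;
    }
    return 1;
}
```

Without the inner parentheses, the first line would compute (13 - 6) & 5 = 5 for n = 13 instead of 9.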

For 32 bits it is the same idea; we just have to do more of it.

Say n has the bits {a[0], a[1], ..., a[31]}, from most to least significant, so n = a[0]*2^31 + a[1]*2^30 + ... + a[31]*2^0.

The mask should be 32-bit, so instead of 0x5, we use 0x55555555 = 0b0101010101...0101
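The 32-bit masks are just the small patterns repeated across the word. As an illustration (repeat_nibble is a hypothetical helper, not part of the algorithm), repeating the nibble 0x5 eight times gives 0x55555555, and repeating 0x3 gives 0x33333333, the mask used in the second line of the code.

```c
#include <stdint.h>

/* Hypothetical helper: copy a 4-bit pattern into all 8 nibbles of a 32-bit word. */
uint32_t repeat_nibble(uint32_t nib)
{
    uint32_t mask = 0;
    for (int i = 0; i < 8; i++)
        mask = (mask << 4) | (nib & 0xF);   /* append one more copy at the bottom */
    return mask;
}
```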

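n                       = 2^31*a[0] + 2^30*a[1] + 2^29*a[2] + 2^28*a[3] + ... + 2^1*a[30] + 2^0*a[31]
(n>>1)&0x55555555       = 2^30*a[0]             + 2^28*a[2]             + ... + 2^0*a[30]
------------------------------------------------------------------------------------------------------ -
n - ((n>>1)&0x55555555) = (2^31 - 2^30)*a[0] + 2^30*a[1] + (2^29 - 2^28)*a[2] + 2^28*a[3] + ... + (2^1 - 2^0)*a[30] + 2^0*a[31]


Let's review some binary arithmetic:

2^3 - 2^2 = 8 - 4 = 4 = 2^2
2^16 - 2^15 = 65536 - 32768 = 32768 = 2^15
or in general: 2^(k+1) - 2^k = 2^k

So the result is:

2^30*a[0] + 2^30*a[1] + 2^28*a[2] + 2^28*a[3] + ... + 2^0*a[30] + 2^0*a[31]
= 2^30*(a[0] + a[1]) + 2^28*(a[2] + a[3]) + ... + 2^0*(a[30] + a[31])

This is the 4-bit result 4*(a + b) + (c + d) repeated for all 16 bit pairs: each 2-bit field now holds the number of 1 bits in its pair. That count is 0, 1 or 2, so it always fits in its field and the subtraction never borrows from a neighbouring pair.

Store this as the new n. Then:

n>>2 = 2^28*(a[0] + a[1]) + 2^26*(a[2] + a[3]) + ... + 2^0*(a[28] + a[29])

Masking both n and n>>2 with 0x33333333 keeps every other 2-bit field, so the second line of the code adds neighbouring pair counts into 4-bit counts:

(n & 0x33333333) + ((n>>2) & 0x33333333) = 2^28*(a[0] + a[1] + a[2] + a[3]) + 2^24*(a[4] + ... + a[7]) + ... + 2^0*(a[28] + ... + a[31])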
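A small sketch to check the first two lines of bitcount on real 32-bit values. After the first line every 2-bit field should hold the number of 1 bits of n in the same position, and after the second line every 4-bit field should. The names ones_in and fields_ok are hypothetical; ones_in is a plain bit-by-bit loop used only as a reference.

```c
#include <stdint.h>

/* Reference: count the 1 bits in the lowest `width` bits of x, one at a time. */
uint32_t ones_in(uint32_t x, int width)
{
    uint32_t cnt = 0;
    for (int i = 0; i < width; i++)
        cnt += (x >> i) & 1;
    return cnt;
}

/* Returns 1 if every 2-bit field after step 1, and every 4-bit field after
   step 2, holds the number of 1 bits of n in the same position; otherwise 0. */
int fields_ok(uint32_t n)
{
    uint32_t s1 = n - ((n >> 1) & 0x55555555u);                   /* step 1 */
    uint32_t s2 = (s1 & 0x33333333u) + ((s1 >> 2) & 0x33333333u); /* step 2 */
    for (int k = 0; k < 32; k += 2)
        if (((s1 >> k) & 0x3) != ones_in(n >> k, 2))
            return 0;
    for (int k = 0; k < 32; k += 4)
        if (((s2 >> k) & 0xF) != ones_in(n >> k, 4))
            return 0;
    return 1;
}
```

For instance, fields_ok(0x12345678) returns 1, and so should any other 32-bit value.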

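The rest is the same kind of manipulation, repeated with wider fields until only a[0] + a[1] + ... + a[31] is left. Once the second line has produced 4-bit counts (each at most 4), (n + (n >> 4)) & 0x0F0F0F0F adds neighbouring 4-bit counts into one count per byte (each at most 8), and multiplying by 0x01010101 adds the four byte counts together in the top byte. The total is at most 32, so no byte overflows into the next, and in 32-bit unsigned arithmetic the partial products above bit 31 are simply dropped; >> 24 then extracts the top byte.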
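To watch the last steps, here is a sketch that keeps each intermediate value (struct stages and popcount_stages are hypothetical names, not part of the original code). For n = 0xFFFFFFFF, the 4-bit counts are 0x44444444, the per-byte counts are 0x08080808, and the top byte of the product is 32.

```c
#include <stdint.h>

/* Hypothetical container for the intermediate results of the 32-bit count. */
struct stages {
    uint32_t nibbles;  /* eight 4-bit counts, each 0..4 */
    uint32_t bytes;    /* four 8-bit counts, each 0..8 */
    uint32_t total;    /* final count, 0..32 */
};

struct stages popcount_stages(uint32_t n)
{
    struct stages s;
    n = n - ((n >> 1) & 0x55555555u);                          /* 2-bit counts */
    s.nibbles = (n & 0x33333333u) + ((n >> 2) & 0x33333333u);  /* 4-bit counts */
    s.bytes = (s.nibbles + (s.nibbles >> 4)) & 0x0F0F0F0Fu;    /* 8-bit counts */
    /* Multiplying by 0x01010101 adds all four bytes into the top byte;
       unsigned arithmetic drops the bits above bit 31. */
    s.total = (s.bytes * 0x01010101u) >> 24;
    return s;
}
```

For 0x01020304 the per-byte counts come out as 0x01010201 (the bytes 0x01, 0x02, 0x03, 0x04 have 1, 1, 2 and 1 bits set) and the total is 5.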

A variant, this one from folks at MIT (it is based on item 169 of the HAKMEM memo):

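#include <stdint.h>

/* Octal variant: works on 3-bit groups. Assumes a 32-bit input. */
int bitcount(uint32_t n)
{
    /* each 3-bit group becomes the count of 1s in that group */
    uint32_t tmp = n - ((n >> 1) & 033333333333) - ((n >> 2) & 011111111111);

    /* add neighbouring groups into 6-bit fields, then sum the fields with % 63 */
    return (int)(((tmp + (tmp >> 3)) & 030707070707) % 63);
}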

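It uses a similar method with a different grouping: it counts bits in 3-bit groups instead of 2-bit pairs, and it adds up the group counts with % 63 instead of a multiply and shift. (Notice that the masks are written in octal. Each octal digit covers exactly one 3-bit group, which makes the pattern easy to see and remember; the equivalent hex constants would work too, but they hide it.)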
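The two stages of the octal version can be checked separately as well. In this sketch (octal_stage1, octal_stage2 and sum_base64_digits are hypothetical names) the % 63 result is compared with summing the 6-bit fields one by one; like the function above, it assumes a 32-bit input.

```c
#include <stdint.h>

/* Stage 1: each 3-bit group (one octal digit) becomes the number of 1s in
   that group: for a group worth 4a + 2b + c we subtract (2a + b) and a,
   leaving a + b + c. The top group of a 32-bit word has only 2 bits. */
uint32_t octal_stage1(uint32_t n)
{
    return n - ((n >> 1) & 033333333333u) - ((n >> 2) & 011111111111u);
}

/* Stage 2: add each group to its neighbour and keep every other group, so
   the word holds six 6-bit fields, i.e. base-64 digits. Because 64 = 63 + 1,
   a base-64 number leaves the same remainder mod 63 as the sum of its
   digits, and that sum (at most 32) is already smaller than 63. */
uint32_t octal_stage2(uint32_t t)
{
    uint32_t fields = (t + (t >> 3)) & 030707070707u;
    return fields % 63;
}

/* For comparison: add the 6-bit fields one at a time. */
uint32_t sum_base64_digits(uint32_t x)
{
    uint32_t sum = 0;
    for (; x != 0; x >>= 6)
        sum += x & 077;   /* lowest 6-bit field */
    return sum;
}
```

For 0xFFFFFFFF, octal_stage1 gives 023333333333 (a 2 for the top 2-bit group and a 3 for each full 3-bit group), and octal_stage2 of that is 32.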