Challenge 4


1. Original assembly

Original challenge link: Click me.

This is the original function, adapted to NASM syntax.

func:
    mov     edx, edi
    shr     edx, 1
    and     edx, 0x55555555         ; 0b01010101010101010101010101010101
    sub     edi, edx

    mov     eax, edi
    shr     edi, 2
    and     eax, 0x33333333         ; 0b00110011001100110011001100110011
    and     edi, 0x33333333         ; 0b00110011001100110011001100110011
    add     edi, eax

    mov     eax, edi
    shr     eax, 4
    add     eax, edi
    and     eax, 0x0f0f0f0f         ; 0b00001111000011110000111100001111

    imul    eax, eax, 0x01010101    ; 0b00000001000000010000000100000001
    shr     eax, 24
    ret

2. Analysis

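First, it copies the function argument from edi into edx and shifts it one bit to the right, which moves every bit at an odd position (1, 3, …, 31) down into the even position next to it. It then ANDs edx with 0x55555555 to keep only the even bits (0, 2, …, 30), so edx now holds the original odd bits, each aligned with its even neighbor.

This value in edx is then subtracted from the original argument in edi. A 2-bit pair holding the value 2·hi + lo becomes (2·hi + lo) - hi = hi + lo, so every pair now holds the number of set bits it originally contained (0, 1 or 2).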

mov     edx, edi
shr     edx, 1
and     edx, 0x55555555 ; 0b01010101010101010101010101010101
sub     edi, edx

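It then adds the even pairs of bits to the odd pairs: eax keeps the low pair of every nibble (AND with 0x33333333), while edi is shifted two bits to the right and masked the same way to bring the high pair of every nibble down. After the addition, each 4-bit nibble holds the number of set bits in that nibble of the original argument (0 to 4).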

mov     eax, edi
shr     edi, 2
and     eax, 0x33333333 ; 0b00110011001100110011001100110011
and     edi, 0x33333333 ; 0b00110011001100110011001100110011
add     edi, eax

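After that, it copies the result into eax, shifts it 4 bits to the right, and adds the unshifted value still held in edi. Each nibble count is at most 4, so every sum is at most 8 and never carries into the next nibble; the low nibble of each byte now holds the bit count of that whole byte.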

mov     eax, edi
shr     eax, 4
add     eax, edi

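Then, it clears the upper nibble of every byte, which only holds a leftover partial sum, so that each byte holds just its own bit count, and multiplies the result by 0x01010101. Keep in mind that since eax is the destination of this three-operand imul, the product is truncated and only its lower 32 bits are kept.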

and     eax, 0x0f0f0f0f      ; 0b00001111000011110000111100001111
imul    eax, eax, 0x01010101 ; 0b00000001000000010000000100000001

It then shifts the result 24 bits to the right, before returning it.

shr     eax, 24
ret

Let's test it with a few inputs to see what it returns:

Hex input    Binary input                          Output
0x00         0b00000000                            0x0
0x01         0b00000001                            0x1
0x02         0b00000010                            0x1
0x03         0b00000011                            0x2
0x0F         0b00001111                            0x4
0xFF         0b11111111                            0x8
0xDEADBEEF   0b11011110101011011011111011101111    0x18

From this, we can see that the function counts the number of set bits in its input.

3. Why does this work?

After discovering what this function does, I decided to look a bit more into how it works. I found this Stack Overflow link. It turns out this is called calculating the Hamming weight (also known as the population count).

The assembly function is a simplified version of this one:

#include <stdint.h>
#include <stdio.h>

int count_bits(uint32_t x) {
    /* Put count of each 2 bits into those 2 bits */
    x = (x & 0x55555555) + ((x >> 1) & 0x55555555);

    /* Put count of each 4 bits into those 4 bits */
    x = (x & 0x33333333) + ((x >> 2) & 0x33333333);

    /* Put count of each 8 bits into those 8 bits */
    x = (x & 0x0f0f0f0f) + ((x >> 4) & 0x0f0f0f0f);

    /* If (x <= 0xFF), propagates lower byte to the upper ones. I am not sure
     * why this works for bigger values. */
    x *= 0x01010101;
    return x >> 24;
}

int main(void) {
    printf("%d\n", count_bits(0x000000AB)); /* 5 */
    printf("%d\n", count_bits(0xDEADBEEF)); /* 24 */
    return 0;
}

Keep in mind that:

Hex          Binary
0x55555555   0b01010101010101010101010101010101
0x33333333   0b00110011001100110011001100110011
0x0f0f0f0f   0b00001111000011110000111100001111

The final multiplication and shift replace the following two steps, which would otherwise fold the 8-bit counts into 16-bit counts and then into a single 32-bit count:

x = (x & 0x00ff00ff) + ((x >>  8) & 0x00ff00ff);
x = (x & 0x0000ffff) + ((x >> 16) & 0x0000ffff);

I still can't fully understand how this works, though.
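
One way to see why it also works for values bigger than 0xFF (a sketch of the arithmetic, not taken from the linked answer): after the AND with 0x0f0f0f0f, x holds four per-byte bit counts b_3, b_2, b_1, b_0, each between 0 and 8. Multiplying by 0x01010101 is the same as adding x shifted left by 0, 8, 16 and 24 bits, and every term that lands at bit 32 or above is discarded by the 32-bit register:

$$
x = b_3 \cdot 2^{24} + b_2 \cdot 2^{16} + b_1 \cdot 2^{8} + b_0, \qquad 0 \le b_i \le 8
$$

$$
x \cdot (2^{24} + 2^{16} + 2^{8} + 1) \equiv (b_0 + b_1 + b_2 + b_3) \cdot 2^{24} + (b_0 + b_1 + b_2) \cdot 2^{16} + (b_0 + b_1) \cdot 2^{8} + b_0 \pmod{2^{32}}
$$

Each coefficient is at most 32, so it fits in its byte without carrying into the next one, and the top byte ends up holding b_0 + b_1 + b_2 + b_3, the total bit count, which the shift by 24 extracts. The "copy the low byte" effect is the special case b_1 = b_2 = b_3 = 0.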

4. Why do some versions use the 0x01010101 constant?

In the assembly code, this constant is used in the multiplication right before the final shift and return.

imul    eax, eax, 0x01010101

Multiplying small values (less than or equal to 0xFF) by 0x01010101 copies the low byte into the remaining 3 bytes of the DWORD. For example:

      0xAB                          0b10101011
0x01010101  0b00000001000000010000000100000001
---------------------------------------------- (MUL)
0xABABABAB  0b10101011101010111010101110101011

Right after this, it shifts the value 24 bits to the right, discarding the lower 3 bytes and returning the upper one.

shr     eax, 24
ret

Let's compare it to the ARM assembly, one of the examples that doesn't use the 0x01010101 constant:

...
ADD     r0, r0, r0, LSL #16
ADD     r0, r0, r0, LSL #8
LSR     r0, r0, #24
BX      lr

The first instruction shifts the value in r0 16 bits to the left, adds it to r0, and stores the result in r0. The next instruction does the same, but with an 8-bit left shift instead of 16. Finally, LSR does the 24-bit right shift we saw in the other examples.

The last instruction, BX lr, returns from the procedure by branching to the return address held in the link register (lr), as explained in the ARM documentation:

BX - Branch and exchange instruction set

Syntax: BX Rm

The BX instruction causes a branch to the address contained in Rm, and exchanges the instruction set, if required:

  • If bit[0] of Rm is 0, the processor changes to, or remains in, ARM state
  • If bit[0] of Rm is 1, the processor changes to, or remains in, Thumb state.

Translated to C code:

r0 = r0 + (r0 << 16);
r0 = r0 + (r0 << 8);
r0 = r0 >> 24;
return r0;

Now that we know what both versions do before returning, it's not hard to see that these two shift-and-add instructions from ARM do the same as the 0x01010101 multiplication:

Base:               0x000000AB
After first shift:  0x00AB0000
                    ---------- (ADD)
After adding:       0x00AB00AB
After second shift: 0xAB00AB00
                    ---------- (ADD)
After adding again: 0xABABABAB

So the compiler probably chose between shifting and adding or multiplying by a constant depending on which is cheaper on the target CPU.

My final question, however, is: why would you copy the lower byte to the other ones, right before returning the upper one?

5. C translation

Here are two C translations of the x86 assembly code: one uses multiplication for the last step, while the other uses shifts and adds.

#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

uint32_t count_bits(uint32_t num) {
    num -= (num >> 1) & 0x55555555;                       /* First 4 instructions */
    num = ((num >> 2) & 0x33333333) + (num & 0x33333333); /* Next 5 instructions */
    num += (num >> 4);                                    /* Next 3 instructions */
    num &= 0x0f0f0f0f;                                    /* Next instruction */
    num = (num * 0x01010101) >> 24;                       /* Next 2 instructions */
    return num;                                           /* Last instruction */
}

/* Without the 0x01010101 constant */
uint32_t count_bits_shifting(uint32_t num) {
    num -= (num >> 1) & 0x55555555;
    num = ((num >> 2) & 0x33333333) + (num & 0x33333333);
    num += (num >> 4);
    num &= 0x0f0f0f0f;

    num += (num << 16);                                   /* ADD r0, r0, r0, LSL #16 */
    num += (num << 8);                                    /* ADD r0, r0, r0, LSL #8 */
    num >>= 24;                                           /* LSR r0, r0, #24 */
    return num;
}

int main(void) {
    printf("%" PRIu32 "\n", count_bits(0x000000AB));          /* 5 */
    printf("%" PRIu32 "\n", count_bits(0xDEADBEEF));          /* 24 */
    printf("%" PRIu32 "\n", count_bits_shifting(0x000000AB)); /* 5 */
    printf("%" PRIu32 "\n", count_bits_shifting(0xDEADBEEF)); /* 24 */
    return 0;
}

6. Iterative C version

A simpler (but probably slower) way of doing it with a for loop:

#include <stdint.h>
#include <stdio.h>

int count_bits(uint32_t x) {
    int ret = 0;

    /* Test the lowest bit, then shift it out until no set bits remain */
    for (; x > 0; x >>= 1)
        if (x & 1)
            ret++;

    return ret;
}

int main(void) {
    printf("%d\n", count_bits(0x000000AB)); /* 5 */
    printf("%d\n", count_bits(0xDEADBEEF)); /* 24 */
    return 0;
}