Challenge 4
1. Original assembly
Original challenge link: Click me.
This is the original function, adapted to fit the nasm syntax.
    func:
        mov     edx, edi
        shr     edx, 1
        and     edx, 0x55555555     ; 0b01010101010101010101010101010101
        sub     edi, edx
        mov     eax, edi
        shr     edi, 2
        and     eax, 0x33333333     ; 0b00110011001100110011001100110011
        and     edi, 0x33333333     ; 0b00110011001100110011001100110011
        add     edi, eax
        mov     eax, edi
        shr     eax, 4
        add     eax, edi
        and     eax, 0x0f0f0f0f     ; 0b00001111000011110000111100001111
        imul    eax, eax, 0x01010101 ; 0b00000001000000010000000100000001
        shr     eax, 24
        ret
2. Analysis
First of all, it copies the function argument from `edi` into `edx` and shifts it one bit to the right. Then it ANDs `edx` with `0x55555555` to keep only the even bit positions (0, 2, …, 30). At this point, the bits that were at odd positions (1, 3, …, 31) have been moved down to the adjacent even position.
The value in `edx` is then subtracted from the original argument in `edi`.
    mov     edx, edi
    shr     edx, 1
    and     edx, 0x55555555     ; 0b01010101010101010101010101010101
    sub     edi, edx
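To see what that subtraction achieves, look at a single 2-bit field `v`: `v - (v >> 1)` maps `0b00` to 0, `0b01` to 1, `0b10` to 1 and `0b11` to 2, which is the number of set bits in the field. The sketch below applies the same step to the whole 32-bit value. The name `pair_counts` is mine and does not come from the challenge.

```c
#include <stdint.h>

/* Hypothetical helper mirroring the first four instructions.
 * Each 2-bit field of the result holds the number of set bits (0..2)
 * that the same field had in x. */
uint32_t pair_counts(uint32_t x)
{
    return x - ((x >> 1) & 0x55555555u);
}
```

For example, `pair_counts(0xE4)` (binary `11 10 01 00`) gives `0x94` (binary `10 01 01 00`), that is, the per-pair counts 2, 1, 1 and 0.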
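It then splits the value into 2-bit fields and adds the even pairs of bits (kept in `eax`) to the odd pairs of bits (shifted down in `edi`), so each nibble now holds the bit count of the original nibble.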
    mov     eax, edi
    shr     edi, 2
    and     eax, 0x33333333     ; 0b00110011001100110011001100110011
    and     edi, 0x33333333     ; 0b00110011001100110011001100110011
    add     edi, eax
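The same step can be sketched in C, assuming the input already holds per-pair counts (each 2-bit field is 0, 1 or 2). The helper name `nibble_counts` is hypothetical; the two local variables play the roles of `eax` and `edi`.

```c
#include <stdint.h>

/* Hypothetical helper mirroring the five instructions above.
 * Adds neighbouring 2-bit counts, so each 4-bit field of the result
 * holds a count between 0 and 4. */
uint32_t nibble_counts(uint32_t pairs)
{
    uint32_t even = pairs & 0x33333333u;        /* even pairs, like eax */
    uint32_t odd  = (pairs >> 2) & 0x33333333u; /* odd pairs moved down, like edi */
    return even + odd;
}
```

Continuing with the per-pair value `0x94` (counts 2, 1, 1, 0), `nibble_counts(0x94)` returns `0x31`: three set bits in the upper nibble of `0xE4` and one in the lower nibble.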
After that, it copies the result back to `eax`, shifts it 4 bits to the right, and adds the unshifted value still stored in `edi`.
    mov     eax, edi
    shr     eax, 4
    add     eax, edi
Then it discards the upper nibble of each byte of `eax` and multiplies the result by `0x01010101`. Keep in mind that, since `eax` is a 32-bit destination register, the product is truncated to its lower 32 bits.
    and     eax, 0x0f0f0f0f      ; 0b00001111000011110000111100001111
    imul    eax, eax, 0x01010101 ; 0b00000001000000010000000100000001
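The add-then-mask part of the last two steps can be sketched as follows, assuming the input holds per-nibble counts (each nibble is at most 4). Adding before masking is safe under that assumption: two neighbouring nibbles sum to at most 8, which still fits in 4 bits, so nothing carries into the next nibble. The name `byte_counts` is hypothetical.

```c
#include <stdint.h>

/* Hypothetical helper: shift by 4, add, then keep the low nibble of each
 * byte. Each byte of the result holds the bit count (0..8) of the
 * corresponding byte of the original input. */
uint32_t byte_counts(uint32_t nibbles)
{
    return (nibbles + (nibbles >> 4)) & 0x0f0f0f0fu;
}
```

With the per-nibble value `0x31` from a byte such as `0xE4`, `byte_counts(0x31)` returns `0x04`, the number of set bits in that byte.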
It then shifts the result 24 bits to the right, before returning it.
    shr     eax, 24
    ret
Let's test it with some other numbers to see how the result changes:
Hex input | Binary input | Output |
---|---|---|
0x00 | 0b00000000 | 0x0 |
0x01 | 0b00000001 | 0x1 |
0x02 | 0b00000010 | 0x1 |
0x03 | 0b00000011 | 0x2 |
0x0F | 0b00001111 | 0x4 |
0xFF | 0b11111111 | 0x8 |
0xDEADBEEF | 0b11011110101011011011111011101111 | 0x18 |
With this, we can see that it's a function for counting the number of set bits in its input.
3. Why does this work?
After discovering what this function does, I decided to look a bit more into how it works. I found this Stack Overflow link. It turns out this is called calculating the Hamming weight.
The assembly function is the simplified version of this one:
    #include <stdint.h>
    #include <stdio.h>

    int count_bits(uint32_t x) {
        /* Put count of each 2 bits into those 2 bits */
        x = (x & 0x55555555) + ((x >> 1) & 0x55555555);

        /* Put count of each 4 bits into those 4 bits */
        x = (x & 0x33333333) + ((x >> 2) & 0x33333333);

        /* Put count of each 8 bits into those 8 bits */
        x = (x & 0x0f0f0f0f) + ((x >> 4) & 0x0f0f0f0f);

        /* If (x <= 0xFF), propagates lower byte to the upper ones. I am not sure
         * why this works for bigger values. */
        x *= 0x01010101;

        return x >> 24;
    }

    int main(void) {
        printf("%d\n", count_bits(0x000000AB)); /* 5 */
        printf("%d\n", count_bits(0xDEADBEEF)); /* 24 */
        return 0;
    }
Keep in mind that:
Hex | Binary |
---|---|
0x55555555 | 0b01010101010101010101010101010101 |
0x33333333 | 0b00110011001100110011001100110011 |
0x0f0f0f0f | 0b00001111000011110000111100001111 |
The last multiplication and shifting step replaces the following two lines, which would otherwise finish the sum:

    x = (x & 0x00ff00ff) + ((x >> 8) & 0x00ff00ff);
    x = (x & 0x0000ffff) + ((x >> 16) & 0x0000ffff);
I still can't fully understand how this works, though.
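To compare the two endings, here is a small sketch (the function names are hypothetical) that takes a value whose four bytes already hold per-byte counts and finishes it both ways. The two masking lines add neighbouring bytes and then the two 16-bit halves; the multiplication accumulates all four bytes into the top byte in one go. The two only agree under the assumption that every byte is at most 8, which the earlier steps guarantee.

```c
#include <stdint.h>

/* Finish with the two extra masking steps: the total ends up in the
 * low 16 bits. */
uint32_t fold_with_masks(uint32_t x)
{
    x = (x & 0x00ff00ffu) + ((x >> 8) & 0x00ff00ffu);
    x = (x & 0x0000ffffu) + ((x >> 16) & 0x0000ffffu);
    return x;
}

/* Finish with multiply and shift: the total ends up in the top byte,
 * which the shift then moves down. */
uint32_t fold_with_multiply(uint32_t x)
{
    return (x * 0x01010101u) >> 24;
}
```

For `0xDEADBEEF` the per-byte counts are `0x06050607`, and both functions return 24 for that value.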
4. Why do some versions use the `0x01010101` constant?
In the assembly code, the value was used for multiplication, right before shifting and returning.
    imul    eax, eax, 0x01010101
Multiplying small values (less than or equal to `0xFF`) by `0x01010101` clones the bits to the remaining 3 bytes of the `DWORD`. For example:

          0xAB  0b10101011
    0x01010101  0b00000001000000010000000100000001
    ---------------------------------------------- (MUL)
    0xABABABAB  0b10101011101010111010101110101011
Right after this, it shifts the value 24 bits to the right, discarding the lower 3 bytes and returning the highest one.

    shr     eax, 24
    ret
Let's compare it to the ARM assembly, one of the examples without the `0x01010101` constant:

    ...
    ADD r0, r0, r0, LSL #16
    ADD r0, r0, r0, LSL #8
    LSR r0, r0, #24
    BX  lr
The first instruction shifts the value in `r0` 16 bits to the left, adds that to `r0`, and stores the result in `r0`. The next instruction does the same, but shifts by 8 bits instead of 16. Finally, it does the 24-bit right shift we saw in the other examples.
The last instruction is returning from the procedure, as explained in the ARM documentation:
BX - Branch and exchange instruction set
Syntax: `BX Rm`
The `BX` instruction causes a branch to the address contained in `Rm`, and exchanges the instruction set, if required:
- If bit[0] of `Rm` is 0, the processor changes to, or remains in, ARM state.
- If bit[0] of `Rm` is 1, the processor changes to, or remains in, Thumb state.
Translated to C code:
    r0 = r0 + (r0 << 16);
    r0 = r0 + (r0 << 8);
    r0 = r0 >> 24;
    return r0;
Now that we know what both of them are doing before returning, it's not hard to see that these two shifts from ARM are doing the same as the `0x01010101` multiplication:

    Base:                0x000000AB
    After first shift:   0x00AB0000
                         ---------- (ADD)
    After adding:        0x00AB00AB
    After second shift:  0xAB00AB00
                         ---------- (ADD)
    After adding again:  0xABABABAB
So the compiler probably made a choice depending on the performance of shifting and adding vs. multiplying by a constant.
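The equivalence is not limited to single-byte inputs. The two shift-and-add steps multiply by (1 + 2^16) and then by (1 + 2^8), and (1 + 2^16)(1 + 2^8) = 1 + 2^8 + 2^16 + 2^24 = `0x01010101`, so both forms give the same 32-bit result for any input. A minimal sketch to check this, with hypothetical function names:

```c
#include <stdint.h>

/* The two ARM-style shift-and-add steps, before the final right shift. */
uint32_t spread_by_shifts(uint32_t r0)
{
    r0 = r0 + (r0 << 16);
    r0 = r0 + (r0 << 8);
    return r0;
}

/* The x86-style multiplication, truncated to 32 bits like imul does. */
uint32_t spread_by_multiply(uint32_t r0)
{
    return r0 * 0x01010101u;
}
```

Both functions return `0xABABABAB` for `0xAB`, and both return `0x18120D07` for the per-byte counts `0x06050607`.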
My final question, however, is: why would you copy the lower byte to the other ones, right before returning the upper one?
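A possible answer, as far as I can tell: the byte is only "copied" when the other three bytes are zero. In general, multiplying by `0x01010101` adds `x`, `x << 8`, `x << 16` and `x << 24` together, so the top byte of the product receives byte 0 + byte 1 + byte 2 + byte 3. At this point in the routine each byte holds a count between 0 and 8, so no partial sum exceeds 32 and nothing carries from one byte into the next. The sketch below (helper names are mine) compares that top byte against an explicit sum of the bytes.

```c
#include <stdint.h>

/* Hypothetical helper: adds the four bytes of x together. */
uint32_t sum_of_bytes(uint32_t x)
{
    return (x & 0xffu) + ((x >> 8) & 0xffu) + ((x >> 16) & 0xffu) + (x >> 24);
}

/* Hypothetical helper: the multiply-and-shift ending of the routine. */
uint32_t top_byte_of_product(uint32_t x)
{
    return (x * 0x01010101u) >> 24;
}
```

For the per-byte counts of `0xDEADBEEF`, which are `0x06050607`, both helpers return 24. For `0xAB` both return `0xAB`, which is the "copy" case from the example above.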
5. C translation
These are two C translations of the x86 assembly code: one uses multiplication for the last step, while the other uses shifting.
    #include <inttypes.h>
    #include <stdint.h>
    #include <stdio.h>

    uint32_t count_bits(uint32_t num) {
        num -= (num >> 1) & 0x55555555;                       /* First 4 instructions */
        num = ((num >> 2) & 0x33333333) + (num & 0x33333333); /* Next 5 instructions */
        num += (num >> 4);                                    /* Next 3 instructions */
        num &= 0x0f0f0f0f;                                    /* Next instruction */
        num = (num * 0x01010101) >> 24;                       /* Next 2 instructions */
        return num;                                           /* Last instruction */
    }

    /* Without the 0x01010101 constant */
    uint32_t count_bits_shifting(uint32_t num) {
        num -= (num >> 1) & 0x55555555;
        num = ((num >> 2) & 0x33333333) + (num & 0x33333333);
        num += (num >> 4);
        num &= 0x0f0f0f0f;
        num += (num << 16);
        num += (num << 8);
        num >>= 24;
        return num;
    }

    int main(void) {
        /* Both functions return uint32_t, so print with PRIu32 instead of %d */
        printf("%" PRIu32 "\n", count_bits(0x000000AB));          /* 5 */
        printf("%" PRIu32 "\n", count_bits(0xDEADBEEF));          /* 24 */
        printf("%" PRIu32 "\n", count_bits_shifting(0x000000AB)); /* 5 */
        printf("%" PRIu32 "\n", count_bits_shifting(0xDEADBEEF)); /* 24 */
        return 0;
    }
6. Iterative C version
A simpler (but probably slower) way of doing it with a `for` loop:
    #include <stdint.h>
    #include <stdio.h>

    int count_bits(uint32_t x) {
        int ret = 0;

        /* Check the lowest bit, then shift it out, until no set bits remain */
        for (; x > 0; x >>= 1)
            if (x & 1)
                ret++;

        return ret;
    }

    int main(void) {
        printf("%d\n", count_bits(0x000000AB)); /* 5 */
        printf("%d\n", count_bits(0xDEADBEEF)); /* 24 */
        return 0;
    }
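Since the loop is easy to trust, it also works as a reference for checking the faster version. The sketch below restates both under different names and counts how often they disagree over a sequence of pseudo-random inputs. The helper `count_mismatches`, the seed and the generator constants are my own choices for illustration, not part of the challenge.

```c
#include <stdint.h>

/* Bit-by-bit count, same logic as the loop version. */
int count_bits_loop(uint32_t x)
{
    int ret = 0;
    for (; x > 0; x >>= 1)
        if (x & 1)
            ret++;
    return ret;
}

/* Parallel count, same steps as the translation of the assembly. */
int count_bits_parallel(uint32_t x)
{
    x -= (x >> 1) & 0x55555555u;
    x = ((x >> 2) & 0x33333333u) + (x & 0x33333333u);
    x = (x + (x >> 4)) & 0x0f0f0f0fu;
    return (int)((x * 0x01010101u) >> 24);
}

/* Hypothetical helper: returns how many of the sampled inputs give
 * different results in the two versions (0 means they always agreed). */
int count_mismatches(uint32_t samples)
{
    uint32_t x = 0xDEADBEEFu; /* arbitrary starting input */
    int mismatches = 0;

    for (uint32_t i = 0; i < samples; i++) {
        if (count_bits_loop(x) != count_bits_parallel(x))
            mismatches++;
        x = x * 1664525u + 1013904223u; /* linear congruential step for the next input */
    }
    return mismatches;
}
```

Calling `count_mismatches(100000)` should return 0 if the two versions agree on every sampled input. This is only a spot check, not a proof.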