I’ve implemented a HashMap or two in my day, but somehow never actually written a hash function. This seemed like it might be a gap in my knowledge that would be interesting to fill in.
MD5 is one of the simplest “real world” hash functions to implement. For extra fun, I wanted to create it right from the source—RFC 1321, written back in 1992. The relevant parts of that RFC total only seven pages. The remainder is dedicated to a reference implementation in C, which I resolved to ignore, preferring to work straight from the spec.
Numbers in headers indicate corresponding sections in RFC 1321.
Check out the source code if you like.
RFC 1321, section 1 states:
The algorithm takes as input a message of arbitrary length and produces as output a 128-bit “fingerprint” or “message digest” of the input.
We can translate this almost directly into code. We’ll take a reference to some slice of bytes as our argument, and return exactly 16 bytes, which is 128 bits.
fn md5(input: &[u8]) -> [u8; 16]
The next section of the RFC can be skimmed. It defines a word as a 32-bit quantity with its bytes given low-order first (little-endian), clarifies some mathematical notation, and explains what the various operators mean.
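Two of those operators need a little care in Rust. In the RFC, "+" is addition of words modulo \(2^{32}\), and "X <<< s" is a circular shift (rotation), not an ordinary shift. Here’s a minimal sketch of how I read that mapping; `word_add` and `word_rotl` are hypothetical helper names used only to label the two operators, and the rest (not, v, xor, and juxtaposition for AND) map directly onto `!`, `|`, `^`, and `&`.

```rust
// Sketch: RFC 1321 section 2 word operators written as Rust u32 operations.
// The helper names are hypothetical; they only label the mapping.

/// RFC "X + Y": addition of words, modulo 2^32.
fn word_add(x: u32, y: u32) -> u32 {
    x.wrapping_add(y)
}

/// RFC "X <<< s": circular left shift (rotation) by s bit positions.
fn word_rotl(x: u32, s: u32) -> u32 {
    x.rotate_left(s)
}

#[test]
fn operators() {
    // Addition wraps around instead of overflowing.
    assert_eq!(word_add(u32::MAX, 1), 0);
    // Rotation carries the high bit around to bit 0...
    assert_eq!(word_rotl(0x8000_0001, 1), 0x0000_0003);
    // ...whereas an ordinary shift discards it.
    assert_eq!(0x8000_0001u32 << 1, 0x0000_0002);
    // not(X) is the bitwise complement.
    assert_eq!(!0u32, u32::MAX);
}
```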
While I mostly followed the RFC from top to bottom, I often like to write a few tests before embarking on an implementation. These will act as an easy sanity check throughout, and provide a nice “we did it” moment the first time they all pass.
The RFC helpfully provides a small test suite in Appendix A section 5, with entries like these.
MD5 ("") = d41d8cd98f00b204e9800998ecf8427e
MD5 ("abc") = 900150983cd24fb0d6963f7d28e17f72
Let’s translate these into test functions, checking that our md5 function produces the right output.
#[test]
fn empty() {
assert_eq!(md5_hex(b""), "d41d8cd98f00b204e9800998ecf8427e");
}
#[test]
fn abc() {
assert_eq!(md5_hex(b"abc"), "900150983cd24fb0d6963f7d28e17f72");
}
md5_hex is a helper function used to write human-readable tests. It takes its slice of bytes and produces a String of hex digits. Let’s mark it as belonging with the tests using a #[cfg(test)] attribute, so the compiler doesn’t complain about unused code.
#[cfg(test)]
fn md5_hex(input: &[u8]) -> String {
md5(input).map(|b| format!("{:02x}", b)).join("")
}
With this handful of tests running, we’re ready to start writing the md5 function itself.
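One detail of md5_hex is easy to skim past: the `{:02x}` format. Without the width and zero fill, any byte below 0x10 prints as a single digit and the hex string comes out too short. A small sketch (the function names are hypothetical) using four bytes from the empty-string digest, `8f 00 b2 04`, shows the difference.

```rust
// Sketch: why md5_hex formats each byte with {:02x} rather than {:x}.

/// Width 2, zero-filled: every byte becomes exactly two hex digits.
fn hex_padded(bytes: &[u8]) -> String {
    bytes.iter().map(|b| format!("{:02x}", b)).collect()
}

/// No width: bytes below 0x10 lose their leading zero.
fn hex_unpadded(bytes: &[u8]) -> String {
    bytes.iter().map(|b| format!("{:x}", b)).collect()
}

#[test]
fn zero_fill_matters() {
    // Bytes 5..9 of MD5("") = d41d8cd98f00b204e9800998ecf8427e
    let bytes = [0x8f, 0x00, 0xb2, 0x04];
    assert_eq!(hex_padded(&bytes), "8f00b204");
    // 0x00 becomes "0" and 0x04 becomes "4", silently dropping digits.
    assert_eq!(hex_unpadded(&bytes), "8f0b24");
}
```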
Now we’re in the thick of it, essentially translating whatever the RFC says into actual code. First up is padding:
The message is “padded” (extended) so that its length (in bits) is congruent to 448, modulo 512. That is, the message is extended so that it is just 64 bits shy of being a multiple of 512 bits long. Padding is always performed, even if the length of the message is already congruent to 448, modulo 512.
Padding is performed as follows: a single “1” bit is appended to the message, and then “0” bits are appended so that the length in bits of the padded message becomes congruent to 448, modulo 512. In all, at least one bit and at most 512 bits are appended.
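The "at least one bit and at most 512 bits" rule is easier to see with concrete lengths. Here’s a minimal sketch of the arithmetic, assuming whole-byte messages as our &[u8] signature implies, so 448 mod 512 bits becomes 56 mod 64 bytes. `padding_bytes` and `padded_len` are hypothetical helpers, not part of the implementation that follows.

```rust
// Sketch of RFC 1321 section 3.1 padding arithmetic, counted in bytes.
// Assumes the message is a whole number of bytes.

/// Bytes added by padding: one 0x80 byte plus zero bytes, so that the
/// length becomes congruent to 56 modulo 64. Always between 1 and 64.
fn padding_bytes(msg_len: usize) -> usize {
    let r = msg_len % 64;
    if r < 56 { 56 - r } else { 120 - r }
}

/// Total length after padding plus the 8-byte length field.
fn padded_len(msg_len: usize) -> usize {
    msg_len + padding_bytes(msg_len) + 8
}
```

A 55-byte message needs only the single 0x80 byte and still fits in one 64-byte block. A 56-byte message is already congruent to 56 modulo 64, but padding is always performed, so it receives a full 64 bytes and spills into a second block.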
We first produce a version of the input where we can easily add more bytes.
let mut padded: Vec<u8> = input.to_vec();
Add a single one bit, followed by some zeroes (0x80 is 0b10000000)…
padded.push(0x80);
…then, add zero padding out to \(448 \pmod{512}\) bits, which is \(56 \pmod{64}\) bytes.
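// wrapping_sub is safe: usize wraps modulo a power of two that is a
// multiple of 64, so the final % 64 still yields the right zero count.
let padding_bytes =
(56usize.wrapping_sub(padded.len() % 64)) % 64;
padded.resize(padded.len() + padding_bytes, 0u8);
A 64-bit representation of b (the length of the message before the padding bits were added) is appended to the result of the previous step.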
Fair enough, as long as we remember to be little-endian, and to deal in bits rather than bytes.
padded
.extend_from_slice(&(input.len() as u64 * 8)
.to_le_bytes());
This section also tells us:
Let M[0 … N-1] denote the words of the resulting message, where N is a multiple of 16.
let m: Vec<u32> = padded
.chunks(4)
.map(|c| u32::from_le_bytes(c.try_into().unwrap()))
.collect();
let n = m.len();
assert!(n % 16 == 0);
A four-word buffer (A,B,C,D) is used to compute the message digest. Here each of A, B, C, D is a 32-bit register. These registers are initialized to the following values in hexadecimal, low-order bytes first):
word A: 01 23 45 67
word B: 89 ab cd ef
word C: fe dc ba 98
word D: 76 54 32 10
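Because the RFC lists each register low-order byte first, reading those bytes with `u32::from_le_bytes` should reproduce the hex literals our implementation uses. A tiny sketch to confirm it; `INIT_BYTES` and `init_registers` are hypothetical names.

```rust
// Sketch: RFC 1321 section 3.3 lists each initial register value
// low-order byte first, so little-endian decoding yields the u32 literals.
const INIT_BYTES: [[u8; 4]; 4] = [
    [0x01, 0x23, 0x45, 0x67], // word A
    [0x89, 0xab, 0xcd, 0xef], // word B
    [0xfe, 0xdc, 0xba, 0x98], // word C
    [0x76, 0x54, 0x32, 0x10], // word D
];

/// Decode each 4-byte group as a little-endian u32.
fn init_registers() -> [u32; 4] {
    INIT_BYTES.map(u32::from_le_bytes)
}
```

Reading the same bytes big-endian instead would give 0x01234567 for A, which is exactly the mistake this step invites.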
Again, the only real issue we might run into is forgetting to be little-endian.
let mut a: u32 = 0x67452301;
let mut b: u32 = 0xefcdab89;
let mut c: u32 = 0x98badcfe;
let mut d: u32 = 0x10325476;
This is the step where the bulk of the work gets done. The RFC describes four “auxiliary functions”, which we can translate quite directly.
F(X,Y,Z) = XY v not(X) Z
G(X,Y,Z) = XZ v Y not(Z)
H(X,Y,Z) = X xor Y xor Z
I(X,Y,Z) = Y xor (X v not(Z))
fn aux_f(x: u32, y: u32, z: u32) -> u32 {
x & y | !x & z
}
fn aux_g(x: u32, y: u32, z: u32) -> u32 {
x & z | y & !z
}
fn aux_h(x: u32, y: u32, z: u32) -> u32 {
x ^ y ^ z
}
fn aux_i(x: u32, y: u32, z: u32) -> u32 {
y ^ (x | !z)
}
This step uses a 64-element table T[1 … 64] constructed from the sine function. Let T[i] denote the i-th element of the table, which is equal to the integer part of 4294967296 times abs(sin(i)), where i is in radians.
Note that \(4294967296 = 2^{32}\), but I’m following the RFC as closely as possible so will stick with the longhand form.
Originally I wanted to implement this table as a const fn returning [u32; 64], but f64::sin isn’t a const fn, so that won’t compile. We can instead write a normal function…
fn sine_table(i: usize) -> u32 {
(4294967296.0 * (i as f64).sin().abs()) as u32
}
…and “precompute” at the beginning of the md5 function.
This would perform worse than true precomputation if we were computing many md5 hashes, since the table is rebuilt on every call. However, I was more interested in an understandable implementation, so I was willing to trade off some performance to avoid including 64 more magic numbers.
let t: [u32; 65] = std::array::from_fn(sine_table);
The length is 65 because in the RFC the table T is 1-indexed, from 1 to 64 inclusive, so t[0] is simply never used.
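Since the table now comes from floating-point math at runtime, it seems worth spot-checking a few entries. This sketch compares against T[1], T[2], and T[64] as they appear spelled out in the RFC’s reference C code (the one part of the RFC I otherwise ignored); `table` is a hypothetical wrapper around the same `from_fn` call.

```rust
// Sketch: spot-check the runtime sine table against constants that
// RFC 1321's reference code lists explicitly.

/// Integer part of 2^32 * |sin(i)|, with i in radians.
/// `as u32` truncates toward zero, which is the integer part here.
fn sine_table(i: usize) -> u32 {
    (4294967296.0 * (i as f64).sin().abs()) as u32
}

/// Index 0 is unused so that t[i] matches the RFC's 1-based T[i].
fn table() -> [u32; 65] {
    std::array::from_fn(sine_table)
}

#[test]
fn sine_spot_check() {
    let t = table();
    assert_eq!(t[1], 0xd76aa478);
    assert_eq!(t[2], 0xe8c7b756);
    assert_eq!(t[64], 0xeb86d391);
}
```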
Next, the RFC starts describing the algorithm for processing the message 16 words at a time, with some helpful pseudocode.
/* Process each 16-word block. */
For i = 0 to N/16-1 do
/* Copy block i into X. */
For j = 0 to 15 do
Set X[j] to M[i*16+j].
end /* of loop on j */
/* Save A as AA, B as BB, C as CC, and D as DD. */
AA = A
BB = B
CC = C
DD = D
Translating as directly as possible, we end up with straightforwardly correct code:
for i in 0..=n / 16 - 1 {
let mut x = [0u32; 16];
for j in 0..=15 {
x[j] = m[i * 16 + j];
}
let aa = a;
let bb = b;
let cc = c;
let dd = d;
...
}
The next part of the pseudocode doesn’t look much like any programming language I’ve seen.
/* Round 1. */
/* Let [abcd k s i] denote the operation
a = b + ((a + F(b,c,d) + X[k] + T[i]) <<< s). */
/* Do the following 16 operations. */
[ABCD 0 7 1] [DABC 1 12 2] [CDAB 2 17 3] [BCDA 3 22 4]
[ABCD 4 7 5] [DABC 5 12 6] [CDAB 6 17 7] [BCDA 7 22 8]
[ABCD 8 7 9] [DABC 9 12 10] [CDAB 10 17 11] [BCDA 11 22 12]
[ABCD 12 7 13] [DABC 13 12 14] [CDAB 14 17 15] [BCDA 15 22 16]
Each element in square brackets is an “operation”, which we can write as a Rust function. We’ll need to track the correct order for ABCD, the values of KSI (k, s, and i), and the auxiliary function, which changes from round to round. At least the internals of the operation are straightforward to translate, but we have to remember to use wrapping_add rather than +: MD5 relies on addition modulo \(2^{32}\), while Rust’s + panics on overflow in debug builds.
fn operation(
a: u32,
b: u32,
c: u32,
d: u32,
k: usize,
s: u32,
i: usize,
x: &[u32],
t: &[u32],
aux_func: fn(u32, u32, u32) -> u32,
) -> u32 {
b.wrapping_add(
a.wrapping_add(aux_func(b, c, d))
.wrapping_add(x[k])
.wrapping_add(t[i])
.rotate_left(s),
)
}
Originally, I used a list of 64 calls to the operation function, 16 for each round. This gets tedious fairly quickly:
a = operation(a, b, c, d, 0, 7, 1, &x, &t, f);
d = operation(d, a, b, c, 1, 12, 2, &x, &t, f);
c = operation(c, d, a, b, 2, 17, 3, &x, &t, f);
...
One important insight is that each new operation is shifting ABCD to the right by one.
ABCD
DABC
CDAB
BCDA
I liked the a, b, c, d variable names for clarity, so I didn’t want to store the registers in an array and literally rotate it, but we can represent the shifts via a tuple reassignment:
(a, b, c, d) = (d, a, b, c);
Then we always put the result of our operation into a, but a is really aliased to be any one of a, b, c, or d.
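One step is left implicit here: why do the names line up correctly again at the end of each round? Each reassignment rotates the roles right by one, so four reassignments are the identity, and 16 operations per round is four full cycles. A quick sketch with labeled placeholders (`rotate_roles` and `rotate_n` are hypothetical helpers) reproduces the ABCD, DABC, CDAB, BCDA column from the RFC.

```rust
// Sketch: the tuple reassignment between operations rotates the register
// roles right by one; four rotations restore the original order.

/// One reassignment: (a, b, c, d) = (d, a, b, c).
fn rotate_roles<T: Copy>(regs: (T, T, T, T)) -> (T, T, T, T) {
    let (a, b, c, d) = regs;
    (d, a, b, c)
}

/// Apply the reassignment n times.
fn rotate_n<T: Copy>(mut regs: (T, T, T, T), n: usize) -> (T, T, T, T) {
    for _ in 0..n {
        regs = rotate_roles(regs);
    }
    regs
}
```

Calling `rotate_n(('A', 'B', 'C', 'D'), n)` for n = 1, 2, 3 gives DABC, CDAB, and BCDA, and n = 16 (one full round) gives ABCD again.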
After noticing this, we can factor out only the parts that differ between rounds (the auxiliary functions) and the parts that differ between operations (the values for k, s, and i). Let’s make a data structure for that information.
struct Round {
aux: fn(u32, u32, u32) -> u32,
ksi: &'static [(usize, u32, usize); 16],
}
Our implementation becomes more data-oriented, listing out all the KSI numbers for each operation…
const ROUNDS: &[Round; 4] = &[
Round {
aux: aux_f,
ksi: &[
(0, 7, 1),
(1, 12, 2),
(2, 17, 3),
(3, 22, 4),
(4, 7, 5),
...
…and looping through that data to massively reduce the amount of boilerplate code.
// run all 4 rounds
for round in ROUNDS {
// run all 16 operations for that round
for &(k, s, i) in round.ksi {
a = operation(a, b, c, d, k, s, i, &x, &t, round.aux);
(a, b, c, d) = (d, a, b, c);
}
}
Finally, we update the registers, which is fairly straightforward.
/* Then perform the following additions. (That is increment each of the four registers by the value it had before this block was started.) */
A = A + AA
B = B + BB
C = C + CC
D = D + DD
a = a.wrapping_add(aa);
b = b.wrapping_add(bb);
c = c.wrapping_add(cc);
d = d.wrapping_add(dd);
It’s time to return an actual result!
The message digest produced as output is A, B, C, D. That is, we begin with the low-order byte of A, and end with the high-order byte of D.
All that’s left to do is move the resulting values into the right places with a little byte manipulation.
let mut result = [0u8; 16];
for (i, word) in [a, b, c, d].iter().enumerate() {
result[i * 4..(i + 1) * 4].copy_from_slice(&word.to_le_bytes());
}
result
Like the RFC says:
This completes the description of MD5.
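For reference, here is one way the pieces above might fit together in a single file. This is my own assembly of the snippets in this post, not the linked source code: the KSI values elided above are transcribed from RFC 1321 section 3.4, the Round struct stores its table by value instead of by &'static reference, and the length field uses wrapping_mul so that only the low-order 64 bits are kept, as section 3.2 specifies.

```rust
// Sketch: the snippets from this post assembled into one file.
// KSI values transcribed from RFC 1321 section 3.4.

fn aux_f(x: u32, y: u32, z: u32) -> u32 { x & y | !x & z }
fn aux_g(x: u32, y: u32, z: u32) -> u32 { x & z | y & !z }
fn aux_h(x: u32, y: u32, z: u32) -> u32 { x ^ y ^ z }
fn aux_i(x: u32, y: u32, z: u32) -> u32 { y ^ (x | !z) }

// T[i] = integer part of 2^32 * |sin(i)|, i in radians
fn sine_table(i: usize) -> u32 {
    (4294967296.0 * (i as f64).sin().abs()) as u32
}

struct Round {
    aux: fn(u32, u32, u32) -> u32,
    ksi: [(usize, u32, usize); 16],
}

const ROUNDS: [Round; 4] = [
    Round { aux: aux_f, ksi: [
        (0, 7, 1), (1, 12, 2), (2, 17, 3), (3, 22, 4),
        (4, 7, 5), (5, 12, 6), (6, 17, 7), (7, 22, 8),
        (8, 7, 9), (9, 12, 10), (10, 17, 11), (11, 22, 12),
        (12, 7, 13), (13, 12, 14), (14, 17, 15), (15, 22, 16),
    ] },
    Round { aux: aux_g, ksi: [
        (1, 5, 17), (6, 9, 18), (11, 14, 19), (0, 20, 20),
        (5, 5, 21), (10, 9, 22), (15, 14, 23), (4, 20, 24),
        (9, 5, 25), (14, 9, 26), (3, 14, 27), (8, 20, 28),
        (13, 5, 29), (2, 9, 30), (7, 14, 31), (12, 20, 32),
    ] },
    Round { aux: aux_h, ksi: [
        (5, 4, 33), (8, 11, 34), (11, 16, 35), (14, 23, 36),
        (1, 4, 37), (4, 11, 38), (7, 16, 39), (10, 23, 40),
        (13, 4, 41), (0, 11, 42), (3, 16, 43), (6, 23, 44),
        (9, 4, 45), (12, 11, 46), (15, 16, 47), (2, 23, 48),
    ] },
    Round { aux: aux_i, ksi: [
        (0, 6, 49), (7, 10, 50), (14, 15, 51), (5, 21, 52),
        (12, 6, 53), (3, 10, 54), (10, 15, 55), (1, 21, 56),
        (8, 6, 57), (15, 10, 58), (6, 15, 59), (13, 21, 60),
        (4, 6, 61), (11, 10, 62), (2, 15, 63), (9, 21, 64),
    ] },
];

// One [abcd k s i] step: a = b + ((a + aux(b,c,d) + X[k] + T[i]) <<< s)
fn operation(a: u32, b: u32, c: u32, d: u32, k: usize, s: u32, i: usize,
             x: &[u32], t: &[u32], aux: fn(u32, u32, u32) -> u32) -> u32 {
    b.wrapping_add(
        a.wrapping_add(aux(b, c, d))
            .wrapping_add(x[k])
            .wrapping_add(t[i])
            .rotate_left(s),
    )
}

fn md5(input: &[u8]) -> [u8; 16] {
    let t: [u32; 65] = std::array::from_fn(sine_table);
    // 3.1-3.2: 0x80, zeros up to 56 mod 64 bytes, then the bit length
    let mut padded = input.to_vec();
    padded.push(0x80);
    let padding_bytes = (56usize.wrapping_sub(padded.len() % 64)) % 64;
    padded.resize(padded.len() + padding_bytes, 0u8);
    padded.extend_from_slice(&(input.len() as u64).wrapping_mul(8).to_le_bytes());
    // 3.3: initial registers
    let (mut a, mut b, mut c, mut d): (u32, u32, u32, u32) =
        (0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476);
    // 3.4: process each 16-word (64-byte) block
    for block in padded.chunks_exact(64) {
        let mut x = [0u32; 16];
        for (j, w) in block.chunks_exact(4).enumerate() {
            x[j] = u32::from_le_bytes([w[0], w[1], w[2], w[3]]);
        }
        let (aa, bb, cc, dd) = (a, b, c, d);
        for round in &ROUNDS {
            for &(k, s, i) in &round.ksi {
                a = operation(a, b, c, d, k, s, i, &x, &t, round.aux);
                (a, b, c, d) = (d, a, b, c);
            }
        }
        a = a.wrapping_add(aa);
        b = b.wrapping_add(bb);
        c = c.wrapping_add(cc);
        d = d.wrapping_add(dd);
    }
    // 3.5: A, B, C, D, each low-order byte first
    let mut result = [0u8; 16];
    for (i, word) in [a, b, c, d].iter().enumerate() {
        result[i * 4..(i + 1) * 4].copy_from_slice(&word.to_le_bytes());
    }
    result
}

fn md5_hex(input: &[u8]) -> String {
    md5(input).iter().map(|b| format!("{:02x}", b)).collect()
}
```

Running the Appendix A.5 vectors through md5_hex is the natural check. The 80-character "1234567890" × 8 vector pads out to 128 bytes, so it also exercises the multi-block path that the short vectors never reach.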