mitchell vitez

Haskell-style List Comprehension Syntax with Rust Macros

I recently learned a little bit about Rust’s declarative macros and wanted to try them out on an interesting bit of syntax. What I came up with was a translation of Haskell’s list comprehensions that acts like a fancy form of vec!.

This is one of those “I recently learned something fun and want to show it off” kinds of posts. If you’d like to learn much more about macros, I’d recommend The Little Book of Rust Macros.

Here are a few example Haskell list comprehensions:

-- squares
[x * x | x <- [1..4]]

-- tensor product
[a * b | a <- [2..3], b <- [2..4]]

-- tensor product, other order
[a * b | b <- [2..4], a <- [2..3]]
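For comparison, the same three comprehensions can be written by hand with Rust iterator adapters. This is only a sketch for illustration, and the function names are made up. It uses Rust’s inclusive ..= ranges so the bounds match Haskell’s [a..b], and flat_map stands in for the nested generators.

```rust
// Hand-written equivalents of the three Haskell comprehensions above.
// `lo..=hi` is inclusive, like Haskell's [lo..hi].

/// [x * x | x <- [1..4]]
fn squares() -> Vec<i32> {
    (1..=4).map(|x| x * x).collect()
}

/// [a * b | a <- [2..3], b <- [2..4]]
/// The outer generator is `a`; `b` varies fastest.
fn tensor_product() -> Vec<i32> {
    (2..=3)
        .flat_map(|a| (2..=4).map(move |b| a * b))
        .collect()
}

/// [a * b | b <- [2..4], a <- [2..3]]
/// Same product, but now `b` is the outer generator.
fn tensor_product_other_order() -> Vec<i32> {
    (2..=4)
        .flat_map(|b| (2..=3).map(move |a| a * b))
        .collect()
}
```

Swapping the generators changes only the order of the results, not their contents.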

Unfortunately, some limitations of declarative macros and some differences between the languages make our final syntax a little different. First, macro_rules! only allows an expr fragment to be followed by =>, , or ;, so we need a separator such as ; after the leading expression. Second, Haskell’s [1..5] includes the 5, but Rust’s 1..5 excludes it.
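To see the range difference concretely, we can collect a half-open range and an inclusive range side by side. The helper names below are hypothetical. The point is that Rust’s 1..5 yields the same values as Haskell’s [1..4], while Rust’s 1..=5 matches Haskell’s [1..5].

```rust
/// Collect the half-open range `lo..hi`, which stops before `hi`.
fn half_open(lo: i32, hi: i32) -> Vec<i32> {
    (lo..hi).collect()
}

/// Collect the inclusive range `lo..=hi`, which behaves like Haskell's [lo..hi].
fn inclusive(lo: i32, hi: i32) -> Vec<i32> {
    (lo..=hi).collect()
}
```

The tests that follow stick with the half-open form, so a Haskell range [1..4] shows up there as 1..5.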

However, since I can see into the future and know how things turn out, I can hand you a bunch of tests now that use the final syntax, and we can work towards getting cargo test to pass.

#[cfg(test)]
mod tests {
    // Haskell's [1..4] becomes 1..5 here, since Rust ranges exclude the end.
    #[test]
    fn test_squares() {
        assert_eq!(
            listcomp![x * x; | x <- 1..5],
            vec![1, 4, 9, 16]
        );
    }

    #[test]
    fn test_tensor_product() {
        assert_eq!(
            listcomp![a * b; | a <- 2..4, b <- 2..5],
            vec![4, 6, 8, 6, 9, 12]
        );
    }

    #[test]
    fn test_tensor_product_other_order() {
        assert_eq!(
            listcomp![a * b; | b <- 2..5, a <- 2..4],
            vec![4, 6, 6, 9, 8, 12]
        );
    }
}

Writing the first macro rule

As the tests show, we need a macro called listcomp. We’ll define it with macro_rules! and mark it with #[macro_export] so that it is exported from the crate.

#[macro_export]
macro_rules! listcomp {
    // TODO
}

Before writing any rules, it helps to think about the structure of the input the macro receives. It has two parts: an expression before the |, and a comma-separated list of clauses like x <- 1..5 after it. Let’s deal with the pre-| part first.

We can match the initial expression with a fragment like $name:expr. I can’t think of a great name here, so let’s go with $expr. Rust only allows an expr fragment to be followed by a few specific tokens (=>, , or ;), so the | can’t come directly after it. Of those, ; seems like the most reasonable separator, so let’s go with that. Then we can match our | token. In Rust, macros operate on “token trees”, and the tt fragment matches any single token tree. Let’s call those $rest and use $( )+ to grab one or more of them. This grabs the entire rest of the macro invocation for us to work with later.

macro_rules! listcomp {
    [$expr:expr; | $($rest:tt)+] =>
      ( /* TODO */ )
}
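To see what this matcher accepts on its own, here is a hypothetical head_only! macro. It uses the same pattern, but it simply evaluates the leading expression and discards the captured token trees. Because $rest is matched as raw tt tokens, the macro accepts anything after the |, even tokens that don’t form a valid Rust expression.

```rust
// Hypothetical macro for illustration only: same matcher as the first
// listcomp! rule, but the transcriber just returns the head expression.
macro_rules! head_only {
    [$expr:expr; | $($rest:tt)+] => {
        // `$rest` is captured but intentionally unused here.
        $expr
    };
}
```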

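We now have access to an expression $expr and the remaining tokens $rest. The rule’s body creates a mutable vector, makes a recursive call to listcomp! that fills it in, and finally returns it. The recursive call passes along the vector’s name (result), the $expr we grabbed in this rule, and the rest of the tokens, so another listcomp! rule can use all of them. The body is wrapped in ({ ... }). The outer parentheses delimit the macro’s output, and the inner braces make that output a single block expression whose value is result.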

macro_rules! listcomp {
    [$expr:expr; | $($rest:tt)+] => ({
        let mut result = vec![];
        listcomp!(result; $expr; $($rest)+);
        result
    });
}

Adding <- expressions

Now we can write more rules. The initial expression is taken care of, so next we need to handle the list of clauses, which looks like a <- as, b <- bs, c <- cs. We’ll split this into two rules. A base case handles a single binding (there will always be at least one). A recursive case handles one binding followed by the rest of the list.

In each case, we’ll want to match the $result and the $expr that we passed along from our initial recursive invocation of listcomp!. Since $result corresponds to the variable name result, it’s an identifier (ident). We match the a <- as syntax into a $var for a and a $list for as.

    ($result:ident; $expr:expr; $var:ident <- $list:expr) =>
      { /* TODO */ }

Given those pieces, and knowing that $list is an expression like 1..10, we can build the base-case code that takes each $var (a) in the $list (as) and pushes the value of the comprehension’s $expr (in the context of that specific value for $var) to the vector $result.

    ($result:ident; $expr:expr; $var:ident <- $list:expr) => {
        for $var in $list {
            $result.push($expr);
        }
    };
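With only the entry rule and this base case, the macro can already handle comprehensions that have a single binding. The sketch below assembles just those two rules. It leaves off #[macro_export] so it can live in a single file. A comprehension with more than one binding won’t match any rule yet.

```rust
// A minimal sketch: the entry rule plus the single-binding base case.
macro_rules! listcomp {
    // Entry rule: create the result vector, fill it via the internal rules, return it.
    [$expr:expr; | $($rest:tt)+] => ({
        let mut result = vec![];
        listcomp!(result; $expr; $($rest)+);
        result
    });
    // Base case: exactly one `var <- list` clause; push `$expr` for each element.
    ($result:ident; $expr:expr; $var:ident <- $list:expr) => {
        for $var in $list {
            $result.push($expr);
        }
    };
}
```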

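Finally, as you might imagine, the recursive case uses $($rest:tt)+ and the comma to peel off one binding at a time. It builds nested for loops by calling back into listcomp! until the base case is hit. Each later binding becomes a more deeply nested loop, so later bindings vary fastest. That is why the two tensor-product tests expect their results in different orders.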

    ( $result:ident;
      $expr:expr;
      $var:ident <- $list:expr, $($rest:tt)+
    ) => {
        for $var in $list {
            listcomp![$result; $expr; $($rest)+];
        }
    };
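Putting the three rules together gives a guard-free version of the macro that should pass the first three tests. Here is a minimal sketch assembled from the rules above, again without #[macro_export]:

```rust
macro_rules! listcomp {
    // Entry rule: set up the result vector and hand the clauses to the internal rules.
    [$expr:expr; | $($rest:tt)+] => ({
        let mut result = vec![];
        listcomp!(result; $expr; $($rest)+);
        result
    });
    // Base case: the last binding pushes `$expr` for each element of `$list`.
    ($result:ident; $expr:expr; $var:ident <- $list:expr) => {
        for $var in $list {
            $result.push($expr);
        }
    };
    // Recursive case: one binding followed by more clauses becomes an outer loop.
    ($result:ident; $expr:expr; $var:ident <- $list:expr, $($rest:tt)+) => {
        for $var in $list {
            listcomp!($result; $expr; $($rest)+);
        }
    };
}
```

The base and recursive binding rules can appear in either order. The base case only matches when nothing follows $list, and the recursive case requires a comma.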

Adding guards

We now have a pretty useful chunk of list comprehension syntax built up, but we can add even more in the form of guards. Here are some Haskell examples:

-- guard
[x * x | x <- [1..4], x /= 3]
-- multiple guards
[x * x | x <- [1..9], x > 5, x < 8]
-- guards for multiple variables
[x + y | x <- [1..9], y <- [3..7], y > 5, x < 5]
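For reference, here is how these guarded comprehensions could be written by hand with iterator adapters, with each guard becoming a filter. The function names are again invented for illustration, and the ranges use ..= to match Haskell’s inclusive bounds.

```rust
/// [x * x | x <- [1..4], x /= 3]
fn guard() -> Vec<i32> {
    (1..=4).filter(|&x| x != 3).map(|x| x * x).collect()
}

/// [x * x | x <- [1..9], x > 5, x < 8]
/// Multiple guards on the same binding combine with &&.
fn multiple_guards() -> Vec<i32> {
    (1..=9).filter(|&x| x > 5 && x < 8).map(|x| x * x).collect()
}

/// [x + y | x <- [1..9], y <- [3..7], y > 5, x < 5]
/// Generate every (x, y) pair first, then filter on both variables.
fn multiple_variable_guards() -> Vec<i32> {
    (1..=9)
        .flat_map(|x| (3..=7).map(move |y| (x, y)))
        .filter(|&(x, y)| y > 5 && x < 5)
        .map(|(x, y)| x + y)
        .collect()
}
```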

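Once again, we can write these out as tests, with the ranges shifted as before to account for Rust’s exclusive upper bound. One tiny difference to note: Haskell’s inequality operator is /=, while Rust’s is !=.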

#[test]
fn test_guard() {
    assert_eq!(
        listcomp![x * x; | x <- 1..5, x != 3],
        vec![1, 4, 16]
    )
}
#[test]
fn test_multiple_guards() {
    assert_eq!(
        listcomp![x * x; | x <- 1..10, x > 5, x < 8],
        vec![36, 49]
    )
}
#[test]
fn test_multiple_variable_guards() {
    assert_eq!(
        listcomp![x + y; | x <- 1..10, y <- 3..8, y > 5, x < 5],
        vec![7, 8, 8, 9, 9, 10, 10, 11]
    )
}

Now, next to the line that sets up an empty result with vec!, we set a default guard. By default the guard is true, since that lets everything through. Then we pass the guard along to all the other macro rules.

    [$expr:expr; | $($rest:tt)+] => ({
        let mut result = vec![];
        let guards = true;
        listcomp!(result; $expr; guards; $($rest)+);
        result
    });

Our base case rule for the a <- as case should now check the guard before pushing to the result vector.

    ($result:ident; $expr:expr; $guards:ident; $var:ident <- $list:expr) => {
        for $var in $list {
            if $guards {
                $result.push($expr);
            }
        }
    };

If we’re in the recursive case and see an a <- as expression, we just pass along the guard unchanged.

    ($result:ident; $expr:expr; $guards:ident; $var:ident <- $list:expr, $($rest:tt)+) => {
        for $var in $list {
            listcomp![$result; $expr; $guards; $($rest)+];
        }
    };

If the base case is a guard expression (the last clause is a guard rather than a binding), we push only when both that guard and the accumulated guards are true.

    ($result:ident; $expr:expr; $guards:ident; $guard:expr) => {
        if $guards && $guard {
            $result.push($expr);
        }
    };

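Finally, in the recursive case, we && the current guard with the accumulated guards and pass the combined value along as the new guard. Since macro_rules! tries its rules in order, these guard rules should come after the binding rules. That way, a clause like x <- 1..5 is always tried as a binding first. We’ve successfully added guards to our comprehensions.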

    ($result:ident; $expr:expr; $guards:ident; $guard:expr, $($rest:tt)+) => {
        let newguard = $guards && $guard;
        listcomp![$result; $expr; newguard; $($rest)+];
    };
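Here are all five rules assembled into a single macro. This is a sketch based on the snippets above, with #[macro_export] omitted so it can be pasted into one file. The binding rules are listed before the guard rules so that a clause like x <- 1..5 is tried as a binding first.

```rust
macro_rules! listcomp {
    // Entry rule: result vector plus a default guard that lets everything through.
    [$expr:expr; | $($rest:tt)+] => ({
        let mut result = vec![];
        let guards = true;
        listcomp!(result; $expr; guards; $($rest)+);
        result
    });
    // Last clause is a binding: loop, pushing only when the accumulated guards hold.
    ($result:ident; $expr:expr; $guards:ident; $var:ident <- $list:expr) => {
        for $var in $list {
            if $guards {
                $result.push($expr);
            }
        }
    };
    // A binding followed by more clauses: loop and recurse with guards unchanged.
    ($result:ident; $expr:expr; $guards:ident; $var:ident <- $list:expr, $($rest:tt)+) => {
        for $var in $list {
            listcomp!($result; $expr; $guards; $($rest)+);
        }
    };
    // Last clause is a guard: push only if it and the accumulated guards hold.
    ($result:ident; $expr:expr; $guards:ident; $guard:expr) => {
        if $guards && $guard {
            $result.push($expr);
        }
    };
    // A guard followed by more clauses: fold it into a new guard and recurse.
    ($result:ident; $expr:expr; $guards:ident; $guard:expr, $($rest:tt)+) => {
        let newguard = $guards && $guard;
        listcomp!($result; $expr; newguard; $($rest)+);
    };
}
```

Each newguard introduced by the last rule is hygienic to its own expansion. As a result, repeated guards produce separate bindings rather than clashing, and each one is referred to through the identifier passed down in $guards.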

Note that a comprehension that uses a variable in a guard before the binding that introduces it will fail to compile:

listcomp![x; | x == 3, x <- 1..5]

error[E0425]: cannot find value `x` in this scope
  --> src/lib.rs:63:35
   |
63 |     assert_eq!(listcomp![x; | x == 3, x <- 1..5], vec![1, 4, 16])
   |                               ^ not found in this scope

However, this is the same behavior as in Haskell list comprehensions, so it seems alright to me.

[x | x == 3, x <- [1..4]]

<interactive>:1:6: error: Variable not in scope: x

Trying it out

If we add a function with this listcomp! invocation…

listcomp![x + y + z; | x <- 1..10, y <- 3..8, y > 5, x < 5, z <- 1..3, z > 2];

…and compile with rustc +nightly -Zunpretty=expanded, we can see what the macro actually expands to.

{
    let mut result = ::alloc::vec::Vec::new();
    let guards = true;
    for x in 1..10 {
        for y in 3..8 {
            let newguard = guards && y > 5;
            let newguard = newguard && x < 5;
            for z in 1..3 {
                if newguard && z > 2 { result.push(x + y + z); };
            };
            ;
        };
    };
    result
};

It’s not the prettiest code ever (metaprogramming rarely is), but it gets the job done.
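As a sanity check on what that expansion computes, here is the same code written out by hand as an ordinary function (the name expanded_example is made up). With z ranging over 1..3, z only takes the values 1 and 2. The z > 2 guard therefore never passes, and this particular comprehension returns an empty vector.

```rust
// Hand-written copy of the expanded code for
// listcomp![x + y + z; | x <- 1..10, y <- 3..8, y > 5, x < 5, z <- 1..3, z > 2]
fn expanded_example() -> Vec<i32> {
    let mut result = Vec::new();
    let guards = true;
    for x in 1..10 {
        for y in 3..8 {
            // Guards that appear between bindings are folded into `newguard`.
            let newguard = guards && y > 5;
            let newguard = newguard && x < 5;
            for z in 1..3 {
                // z is only ever 1 or 2 here, so `z > 2` never holds.
                if newguard && z > 2 {
                    result.push(x + y + z);
                }
            }
        }
    }
    result
}
```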