Mitchell Vitez

Type-Enforced Exponential Trees

For our purposes here, an exponential tree is a tree in which every nonempty node at depth \(d\) has exactly \(2^d\) children, where the root node is at depth 1. For example, a node at depth 1 (the root) has 2 children, a node at depth 2 has 4, and a node at depth 3 has 8.

Today we’re going to come up with a way to create exponential trees such that this property is checked by the types. That is, we’re going to try to make it impossible for anyone to construct an invalid tree with the tree type we come up with. I’ll reveal the code in bits and pieces, but if you’d like to see the whole implementation at once, please check out the GitHub repo. You might also want to reference it if you’re following along, since you may just be missing a language extension or something.

First, let’s talk about plain old trees. Here’s a type for a rose tree in Haskell:

data RoseTree a = Empty | RoseTree a [RoseTree a]
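Here is a small sketch that makes the missing bound concrete. It uses the RoseTree type above, adds a derived Show instance, and defines a hypothetical helper, arities, that lists how many children each non-empty node has. Sibling nodes are free to have three, one, or zero children, and the typechecker accepts all of it:

```haskell
-- The rose tree from above; nothing in the type bounds the child lists.
data RoseTree a = Empty | RoseTree a [RoseTree a]
  deriving Show

-- Hypothetical helper: the number of children of each non-empty node,
-- listed in pre-order.
arities :: RoseTree a -> [Int]
arities Empty = []
arities (RoseTree _ kids) = length kids : concatMap arities kids

-- The root has three children; they in turn have one, zero, and one child
-- (node 4's single child happens to be Empty).
example :: RoseTree Int
example =
  RoseTree 1
    [ RoseTree 2 [RoseTree 5 []]
    , RoseTree 3 []
    , RoseTree 4 [Empty]
    ]
```

In ghci, `arities example` evaluates to `[3,1,0,0,1]`. Each node’s arity differs, and no type error results.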

Because each node has a list of nodes, any level of the tree can have any number of elements. There’s no bound on the length of these lists, so that’s a good place to think about adding a bound in our type. Let’s take a first stab at a type for exponential trees. As in the rose tree, we’ll want some value of type a in each node. We’ll also track the depth as a type parameter d so we can enforce the \(2^d\)-children property. We’ll mark the part we’re not sure about yet with a _ hole.

data ExponentialTree a d
  = Empty
  | ExponentialTree a _

What goes in the hole? We need some kind of fixed-length list. VecPeano from Data.Vector.Fixed should do the trick! Let’s fill that in and leave a hole for the length of the vector. The depth of the next sublevel of the tree is \(d + 1\), or in Peano numerals the successor of d, S d.

data ExponentialTree a d
  = Empty
  | ExponentialTree a (VecPeano _ (ExponentialTree a (S d)))
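Here is a minimal standalone sketch of a vector indexed by Peano numerals. It shows how a fixed-length list can carry its length in its type. It mirrors the shape of fixed-vector’s VecPeano, with constructors Nil and Cons. The names PeanoNum, Vec, toList, and pair are this sketch’s own and are defined locally, so the file compiles with GHC alone:

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures #-}

-- Natural numbers as data; DataKinds promotes Z and S to the type level.
data PeanoNum = Z | S PeanoNum

-- A list whose length is part of its type: Nil has length 'Z, and each
-- Cons adds one 'S. A Vec n a cannot be built with any other length.
data Vec (n :: PeanoNum) a where
  Nil  :: Vec 'Z a
  Cons :: a -> Vec n a -> Vec ('S n) a

-- Forget the length index and recover an ordinary list.
toList :: Vec n a -> [a]
toList Nil         = []
toList (Cons x xs) = x : toList xs

-- Type-checks: the type promises exactly two elements, and we supply two.
pair :: Vec ('S ('S 'Z)) Char
pair = Cons 'a' (Cons 'b' Nil)

-- Rejected at compile time if uncommented (three elements where two are promised):
-- bad :: Vec ('S ('S 'Z)) Char
-- bad = Cons 'a' (Cons 'b' (Cons 'c' Nil))
```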

This new hole is where we want to add our constraint: a node at depth \(d\) should have \(2^d\) children. This is all fine and dandy, but how do we do exponentiation in types? We chose VecPeano for our implementation, and it uses Peano numerals to encode natural numbers in types. So let’s take a quick detour and figure out how to compute \(2^d\) with Peano numerals.

VecPeano’s numbers are defined in Data.Vector.Fixed.Cont. Peano numerals are either zero or the successor of another number, and in this way we can inductively define all the natural numbers. Here’s that library’s type family for converting a type-level natural number (Nat) into a PeanoNum (lightly modified):

type family Peano (n :: Nat) :: PeanoNum where
  Peano 0 = Z
  Peano n = S (Peano (n - 1))

Peano is a type family, which you can think of as a function at the type level: it takes a type-level natural number like 3 and computes the matching PeanoNum. PeanoNum itself is an ordinary data type, data PeanoNum = Z | S PeanoNum, whose constructors the DataKinds extension lifts into types. Its constructors are fully known, much as a Maybe is known to be either Just or Nothing when you pattern match on it. So a PeanoNum is either Z or some nested succession of S that bottoms out in Z.

This lets us represent the natural numbers: zero is Z, one is S Z, two is S (S Z), three is S (S (S Z)), and so on. Having these numbers in types is nice, because the typechecker verifies at compile time that a VecPeano (S (S Z)) has exactly two elements. We cannot construct a VecPeano (S (S Z)) with any other number of elements.
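Here is a standalone sketch that checks the conversion directly. It reproduces the Peano family from above and adds a Refl proof that Peano 3 reduces to three nested S constructors. It also adds a small class, KnownPeano, that reflects a type-level PeanoNum back down to an ordinary Integer. KnownPeano, peanoVal, and threeIsSSSZ are hypothetical names for illustration, not part of fixed-vector:

```haskell
{-# LANGUAGE DataKinds, KindSignatures, TypeFamilies, TypeOperators,
             UndecidableInstances, ScopedTypeVariables, FlexibleInstances #-}

import Data.Proxy (Proxy (..))
import Data.Type.Equality ((:~:) (..))
import GHC.TypeLits (Nat, type (-))

data PeanoNum = Z | S PeanoNum

-- Convert a type-level literal such as 3 into nested successors.
type family Peano (n :: Nat) :: PeanoNum where
  Peano 0 = 'Z
  Peano n = 'S (Peano (n - 1))

-- Compile-time check: GHC accepts Refl only if both sides reduce to the same type.
threeIsSSSZ :: Peano 3 :~: 'S ('S ('S 'Z))
threeIsSSSZ = Refl

-- Hypothetical helper class: reflect a PeanoNum back to a runtime Integer.
class KnownPeano (n :: PeanoNum) where
  peanoVal :: Proxy n -> Integer

instance KnownPeano 'Z where
  peanoVal _ = 0

instance KnownPeano n => KnownPeano ('S n) where
  peanoVal _ = 1 + peanoVal (Proxy :: Proxy n)
```

With DataKinds enabled, `peanoVal (Proxy :: Proxy (Peano 3))` evaluates to `3`.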

Data.Vector.Fixed.Cont also helpfully provides Peano addition, so we don’t have to implement that from scratch. Zero plus any natural number k equals k, so that’s our base case. Our recursive case peels one S off the first argument (turning S n into n), adds the smaller numbers, and wraps the result in one more S. Eventually the first argument reaches Z, and the base case finishes the addition. Here’s the definition in code:

type family Add (n :: PeanoNum) (k :: PeanoNum) :: PeanoNum where
  Add Z k = k
  Add (S n) k = S (Add n k)

As you can see, this satisfies the defining properties of addition on natural numbers: \(0 + k = k\) and \((n + 1) + k = 1 + (n + k)\). These two equations are all we need. The recursive case always shrinks the first argument, so it eventually bottoms out in Z. The recursion therefore terminates, and the addition works!
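As a worked example, here is how the two Add equations reduce the type-level sum \(2 + 1\). Each step applies the recursive case until the first argument is Z:

```latex
\begin{aligned}
\texttt{Add (S (S Z)) (S Z)}
  &= \texttt{S (Add (S Z) (S Z))} && \text{recursive case} \\
  &= \texttt{S (S (Add Z (S Z)))} && \text{recursive case} \\
  &= \texttt{S (S (S Z))}         && \text{base case: } 0 + k = k
\end{aligned}
```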

We can reason by analogy to come up with Peano multiplication. Zero times any number is zero. We can use addition in our recursive case, satisfying the equation \((n + 1) \cdot m = m + (n \cdot m)\).

type family Mul (n :: PeanoNum) (m :: PeanoNum) :: PeanoNum where
  Mul Z _ = Z
  Mul (S n) m = Add m (Mul n m)
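Tracing \(2 \cdot 2\) the same way shows Mul unrolling into repeated Add until the first argument reaches Z. The shorthand \(\underline{n}\) is introduced here only for readability and stands for the Peano numeral for n, so \(\underline{2}\) is S (S Z):

```latex
\begin{aligned}
\texttt{Mul}\ \underline{2}\ \underline{2}
  &= \texttt{Add}\ \underline{2}\ (\texttt{Mul}\ \underline{1}\ \underline{2}) \\
  &= \texttt{Add}\ \underline{2}\ (\texttt{Add}\ \underline{2}\ (\texttt{Mul Z}\ \underline{2})) \\
  &= \texttt{Add}\ \underline{2}\ (\texttt{Add}\ \underline{2}\ \texttt{Z}) \\
  &= \texttt{Add}\ \underline{2}\ \underline{2} \\
  &= \underline{4}
\end{aligned}
```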

Once we have multiplication, we can do something very similar to create exponentiation. (Note that we’re sidestepping the issue of what \(0^0\) should be here.) Our equations are \(0^n = 0\) for \(n \geq 1\), \(a^0 = 1\), and \(a^{b+1} = a \cdot a^b\).

type family Exp (a :: PeanoNum) (b :: PeanoNum) :: PeanoNum where
  Exp Z (S n) = Z
  Exp a Z = S Z
  Exp a (S b) = Mul a (Exp a b)
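Here is a standalone sketch that checks these families without pulling in fixed-vector. It declares its own PeanoNum and Add, following the equations quoted above, and adds Refl proofs that GHC must reduce at compile time. If any equation were wrong, the module would fail to compile. The proof names and the Two synonym are illustrative:

```haskell
{-# LANGUAGE DataKinds, KindSignatures, TypeFamilies, TypeOperators,
             UndecidableInstances #-}

import Data.Type.Equality ((:~:) (..))

data PeanoNum = Z | S PeanoNum

type family Add (n :: PeanoNum) (k :: PeanoNum) :: PeanoNum where
  Add 'Z k = k
  Add ('S n) k = 'S (Add n k)

type family Mul (n :: PeanoNum) (m :: PeanoNum) :: PeanoNum where
  Mul 'Z _ = 'Z
  Mul ('S n) m = Add m (Mul n m)

type family Exp (a :: PeanoNum) (b :: PeanoNum) :: PeanoNum where
  Exp 'Z ('S n) = 'Z
  Exp a 'Z = 'S 'Z
  Exp a ('S b) = Mul a (Exp a b)

type Two = 'S ('S 'Z)

-- Each proof compiles only if both sides reduce to the same Peano numeral.
twoToTheOne :: Exp Two ('S 'Z) :~: Two
twoToTheOne = Refl

twoToTheTwo :: Exp Two Two :~: Add Two Two
twoToTheTwo = Refl

twoToTheThree :: Exp Two ('S Two) :~: 'S ('S ('S ('S ('S ('S ('S ('S 'Z)))))))
twoToTheThree = Refl

-- Closed family equations are tried in order: Exp 'Z 'Z does not match the
-- first equation (it needs 'S n), so it falls through to Exp a 'Z.
zeroToTheZero :: Exp 'Z 'Z :~: 'S 'Z
zeroToTheZero = Refl
```

Loading this module is itself the test. If GHC can’t reduce both sides of a proof to the same numeral, it reports a type error.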

Fantastic! Now that we have exponentiation in types, we can write \(2^d\) as Exp (S (S Z)) d in our types. Let’s go back to our definition of ExponentialTree, with the hole.

data ExponentialTree a d
  = Empty
  | ExponentialTree a (VecPeano _ (ExponentialTree a (S d)))

We can fill in the hole with \(2^d\).

data ExponentialTree a d
  = Empty
  | ExponentialTree a (VecPeano (Exp (S (S Z)) d) (ExponentialTree a (S d)))
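Here is the vector length that Exp (S (S Z)) d works out to at the first few depths. This gives a concrete sense of what the filled-in hole asks for. The table starts from depth S Z (that is, Peano 1), as in the examples below:

```latex
\begin{array}{lll}
d & \texttt{Exp (S (S Z)) d} & \text{children per node} \\
\hline
\texttt{S Z} & \texttt{S (S Z)} & 2^1 = 2 \\
\texttt{S (S Z)} & \texttt{S (S (S (S Z)))} & 2^2 = 4 \\
\texttt{S (S (S Z))} & \texttt{S (S (S (S (S (S (S (S Z)))))))} & 2^3 = 8
\end{array}
```

Each child vector holds subtrees of type ExponentialTree a (S d). The depth index therefore grows by one per level, and the required vector length doubles as you go down.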

Finally we have a type-safe exponential tree! Let’s quickly create a Show instance for it, so it’s easier to mess around with in ghci. Because Data.Vector.Fixed doesn’t provide it, we’ll also have to make a Show instance for VecPeano, but that shouldn’t be too bad.

instance Show a => Show (ExponentialTree a d) where
  show Empty = "Empty"
  show (ExponentialTree x v) =
    "ExponentialTree " ++ show x ++ " (" ++ show v ++ ")"

instance Show x => Show (VecPeano n x) where
  show Nil = "Nil"
  show (Cons x y) = "Cons (" ++ show x ++ ") (" ++ show y ++ ")"

We can now demonstrate that the typechecker will catch invalid trees for us. Let’s create the simplest non-Empty tree: a root node holding a value, which must point to exactly 2 (since \(2^1 = 2\)) Empty nodes. The type signature fixes the starting depth at 1.

ExponentialTree 7 (Cons Empty (Cons Empty Nil))
  :: ExponentialTree Int (Peano 1)
(Diagram: the root node 7 with edges to two Empty children.)

If we try to create either of the following trees, neither of which is a valid exponential tree, compilation fails. The first has too few nodes in the first layer down, and the second has too many.

ExponentialTree 7 (Cons Empty Nil)
  :: ExponentialTree Int (Peano 1)
ExponentialTree 7 (Cons Empty (Cons Empty (Cons Empty Nil)))
  :: ExponentialTree Int (Peano 1)

Let’s build one slightly more complicated tree. Feel free to mess around with the numbers of nodes and make sure that our constraints are correctly enforced.

ExponentialTree 1
  (Cons Empty
  (Cons (ExponentialTree 2
          (Cons Empty
          (Cons Empty
          (Cons Empty
          (Cons Empty
          Nil)))))
  Nil))
  :: ExponentialTree Int (Peano 1)
(Diagram: root node 1 with two children, Empty and 2; node 2 has four Empty children.)

We can also define a few helper functions to help us build and play with these trees. The most trivial one is empty:

empty :: ExponentialTree a d
empty = Empty

If we had written the following instead, the typechecker would only allow Empty nodes at depth one. (I made this mistake while implementing these trees and couldn’t figure out why valid trees weren’t passing the typechecker.) You probably do, however, want to enforce a depth-one constraint when constructing the top-level node of a non-empty tree.

empty :: ExponentialTree a (Peano 1)
empty = Empty

This is just to say it pays to be careful when you’re dealing with type-level programming! You could try implementing functions like insert or member or size to manipulate these trees.
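Here is one possible take on those exercises, as a standalone sketch rather than the author’s solution. It re-declares stand-ins for PeanoNum and VecPeano locally, using the same constructor names as fixed-vector, so the file compiles with GHC alone. The names root, vecToList, size, member, and example are all hypothetical. root is one way to enforce the depth-one constraint for a non-empty top-level node:

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeFamilies,
             UndecidableInstances #-}

-- Local stand-ins with the same shape as fixed-vector's PeanoNum and VecPeano.
data PeanoNum = Z | S PeanoNum

data VecPeano (n :: PeanoNum) a where
  Nil  :: VecPeano 'Z a
  Cons :: a -> VecPeano n a -> VecPeano ('S n) a

type family Add (n :: PeanoNum) (k :: PeanoNum) :: PeanoNum where
  Add 'Z k = k
  Add ('S n) k = 'S (Add n k)

type family Mul (n :: PeanoNum) (m :: PeanoNum) :: PeanoNum where
  Mul 'Z _ = 'Z
  Mul ('S n) m = Add m (Mul n m)

type family Exp (a :: PeanoNum) (b :: PeanoNum) :: PeanoNum where
  Exp 'Z ('S n) = 'Z
  Exp a 'Z = 'S 'Z
  Exp a ('S b) = Mul a (Exp a b)

data ExponentialTree a (d :: PeanoNum)
  = Empty
  | ExponentialTree a (VecPeano (Exp ('S ('S 'Z)) d) (ExponentialTree a ('S d)))

-- Hypothetical smart constructor: a non-empty root sits at depth one,
-- so it must have exactly two subtrees at depth two.
root :: a -> VecPeano ('S ('S 'Z)) (ExponentialTree a ('S ('S 'Z)))
     -> ExponentialTree a ('S 'Z)
root = ExponentialTree

-- Flatten a fixed-length vector into an ordinary list.
vecToList :: VecPeano n x -> [x]
vecToList Nil         = []
vecToList (Cons x xs) = x : vecToList xs

-- Count the non-empty nodes. The recursive call is at depth 'S d, which is
-- why the type signature is required (polymorphic recursion).
size :: ExponentialTree a d -> Int
size Empty                    = 0
size (ExponentialTree _ kids) = 1 + sum (map size (vecToList kids))

-- Does the value appear anywhere in the tree?
member :: Eq a => a -> ExponentialTree a d -> Bool
member _ Empty                    = False
member y (ExponentialTree x kids) = y == x || any (member y) (vecToList kids)

-- The second example tree from above, built through root.
example :: ExponentialTree Int ('S 'Z)
example =
  root 1
    (Cons Empty
    (Cons (ExponentialTree 2 (Cons Empty (Cons Empty (Cons Empty (Cons Empty Nil)))))
    Nil))
```

With this module loaded, `size example` is `2`, `member 2 example` is `True`, and `member 3 example` is `False`.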

Another fun exercise might be to come up with an ASCII pretty-printer. Try it out! (I recommend adapting some of the code in Data.Tree.Pretty.) A sample solution follows.

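-- Based on Data.Tree.Pretty
putTree :: Show a => ExponentialTree a d -> IO ()
putTree = putStr . unlines . display

-- Render a tree as lines of text: the node's value, then its subtrees below it.
display :: Show a => ExponentialTree a d -> [String]
display Empty = ["Empty"]
display (ExponentialTree x v) =
  show x : drawSubTrees v
    where
      drawSubTrees :: Show a => VecPeano n (ExponentialTree a d) -> [String]
      drawSubTrees Nil = []
      -- The last child gets a corner connector and blank continuation lines.
      drawSubTrees (Cons x Nil) =
        "|" : shift "`- " "   " (display x)
      -- Earlier children keep a vertical bar running down to their siblings.
      drawSubTrees (Cons x xs) =
        "|" : shift "+- " "|  " (display x) ++ drawSubTrees xs

      -- Prefix the first line with `first` and every later line with `other`.
      -- All prefixes are three characters wide so nested levels line up.
      shift first other = Prelude.zipWith (++) (first : repeat other)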

Here’s another link to the GitHub repo.