Mitchell Vitez

A Counting Monad

Imagine a version of the State monad where the state is an integer, the state always starts at 0, and there’s a single operation next which gets you the next number. Let’s call it the Count monad. We can write this kind of thing pretty easily by copying from State.
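Because Count is State with the state fixed to Integer and started at 0, the same idea can be sketched by reusing State from the transformers package directly. The names CountS, nextS, runCountS, and threeLabels are hypothetical. Having next return the current count before incrementing it is an assumption, but it is consistent with the labels starting at 0 later in the post:

```haskell
import Control.Monad.Trans.State (State, runState, state)

-- A Count-like monad built directly from transformers' State,
-- with the state specialized to Integer. These names are
-- illustrative stand-ins, not the definitions used in this post.
type CountS = State Integer

-- Return the current count, then bump the stored count by one.
nextS :: CountS Integer
nextS = state (\n -> (n, n + 1))

-- Always start counting from 0.
runCountS :: CountS a -> (a, Integer)
runCountS m = runState m 0

-- Three calls to nextS yield 0, 1, 2 and leave the count at 3.
threeLabels :: CountS [Integer]
threeLabels = sequence [nextS, nextS, nextS]
```

Here runCountS threeLabels evaluates to ([0,1,2],3). A plain type synonym like this exposes the underlying State, while a dedicated newtype lets the representation stay hidden.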

CountT looks like StateT except the state type is specialized to Integer.

newtype CountT m a =
  CountT { innerRunCountT :: Integer -> m (a, Integer) }

We named CountT’s field innerRunCountT because we’ll want to hide it (for example, by not exporting it from the module). In its place, we can use this version, which always initializes the count to 0:

runCountT :: CountT m a -> m (a, Integer)
runCountT m = innerRunCountT m 0

The Count monad is just CountT over Identity.

import Data.Functor.Identity (Identity, runIdentity)

type Count = CountT Identity

runCount :: Count a -> (a, Integer)
runCount m = runIdentity $ runCountT m

The Functor, Applicative, and Monad instances for CountT m look just like the ones for StateT m, so we can skip over them here.
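For completeness, here is one way to write those instances by mirroring StateT, together with a possible definition of next. The post doesn't show next, so its definition here is an assumption: it returns the current count and then increments it. That behavior matches labels starting at 0 and the final count equaling the number of next calls. The block repeats the earlier definitions so it can be loaded on its own:

```haskell
import Data.Functor.Identity (Identity, runIdentity)

newtype CountT m a =
  CountT { innerRunCountT :: Integer -> m (a, Integer) }

-- Map over the result, leaving the count untouched.
instance Functor m => Functor (CountT m) where
  fmap f (CountT g) = CountT $ \n -> fmap (\(a, n') -> (f a, n')) (g n)

-- pure leaves the count alone; <*> threads it from left to right.
instance Monad m => Applicative (CountT m) where
  pure a = CountT $ \n -> pure (a, n)
  CountT mf <*> CountT mx = CountT $ \n -> do
    (f, n') <- mf n
    (x, n'') <- mx n'
    pure (f x, n'')

-- >>= feeds the count produced by the first action into the second.
instance Monad m => Monad (CountT m) where
  CountT mx >>= k = CountT $ \n -> do
    (x, n') <- mx n
    innerRunCountT (k x) n'

-- Assumed definition: return the current count, then increment it.
next :: Monad m => CountT m Integer
next = CountT $ \n -> pure (n, n + 1)

runCountT :: CountT m a -> m (a, Integer)
runCountT m = innerRunCountT m 0

type Count = CountT Identity

runCount :: Count a -> (a, Integer)
runCount m = runIdentity $ runCountT m
```

With these definitions, runCount (sequence [next, next, next]) gives ([0,1,2],3), and runCount (pure 'x') gives ('x',0), since pure never touches the count.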

Example: Annotated Trees

Consider this definition of a binary tree.

data Tree a
  = Empty
  | Tree a (Tree a) (Tree a)

Let’s introduce a version of this type with an “annotation”. In our case, we’ll want to label each node with a unique integer, starting from 0.

data AnnTree ann val
  = AnnEmpty
  | AnnTree ann val (AnnTree ann val) (AnnTree ann val)

We can easily do this (and any other similar labeling task) with the Count monad.

label :: Tree a -> Count (AnnTree Integer a)
label Empty = pure AnnEmpty
label (Tree val left right) = do
  n <- next
  labeledLeft <- label left
  labeledRight <- label right
  pure $ AnnTree n val labeledLeft labeledRight

labelTree :: Tree a -> AnnTree Integer a
labelTree = fst . runCount . label
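To see label in action on a concrete input, here is a self-contained sketch. To keep it short, Count is stood in for by State Integer from the transformers package, with next and runCount defined to behave as described in this post. The sample tree and the annotations helper are hypothetical, added only for the demonstration:

```haskell
import Control.Monad.Trans.State (State, runState, state)

-- Stand-in for the post's Count: State over an Integer counter
-- that starts at 0 (an assumption made to keep this block short).
type Count = State Integer

next :: Count Integer
next = state (\n -> (n, n + 1))

runCount :: Count a -> (a, Integer)
runCount m = runState m 0

data Tree a
  = Empty
  | Tree a (Tree a) (Tree a)

data AnnTree ann val
  = AnnEmpty
  | AnnTree ann val (AnnTree ann val) (AnnTree ann val)
  deriving (Eq, Show)

-- Same definition as in the post: label the node, then the left
-- subtree, then the right subtree (a pre-order traversal).
label :: Tree a -> Count (AnnTree Integer a)
label Empty = pure AnnEmpty
label (Tree val left right) = do
  n <- next
  labeledLeft <- label left
  labeledRight <- label right
  pure $ AnnTree n val labeledLeft labeledRight

-- Hypothetical input: root 'a', whose left child 'b' has a left
-- child 'c', and whose right child is 'd'.
sample :: Tree Char
sample =
  Tree 'a'
    (Tree 'b' (Tree 'c' Empty Empty) Empty)
    (Tree 'd' Empty Empty)

-- Flatten an annotated tree into (label, value) pairs in pre-order.
annotations :: AnnTree ann val -> [(ann, val)]
annotations AnnEmpty = []
annotations (AnnTree ann val l r) =
  (ann, val) : annotations l ++ annotations r
```

Because label visits each node before its children, annotations (fst (runCount (label sample))) is [(0,'a'),(1,'b'),(2,'c'),(3,'d')], and the final count is 4, one per non-empty node.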

Example: Function Call Counting

Next, consider this naive function that calculates the nth Fibonacci number.

fib :: Integer -> Integer
fib 0 = 0
fib 1 = 1
fib n = fib (n - 1) + fib (n - 2)

We want to find out how many times this function runs if we invoke it with fib 20. Sounds like a job for Count!

First, we wrap our function in Count.

countFib :: Integer -> Count Integer
countFib 0 = next >> pure 0
countFib 1 = next >> pure 1
countFib n = do
  next
  a <- countFib $ n - 1
  b <- countFib $ n - 2
  pure $ a + b

Then, running runCount $ countFib 20 gives (6765,21891). The first number is the 20th Fibonacci number, and the second is the number of times countFib was invoked to find it.
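The second number can be cross-checked without the monad. A call with n >= 2 is one call plus all the calls made by the two recursive calls, and each base case is a single call. So the call count C satisfies C(0) = C(1) = 1 and C(n) = 1 + C(n - 1) + C(n - 2). By induction, C(n) = 2F(n + 1) - 1, where F is the Fibonacci function: the base cases give 2F(1) - 1 = 1 and 2F(2) - 1 = 1, and the step gives 1 + (2F(n) - 1) + (2F(n - 1) - 1) = 2F(n + 1) - 1. Since F(21) = 10946, C(20) = 2 · 10946 - 1 = 21891. A small sketch of that check follows; fibs, callCounts, and calls are hypothetical helpers:

```haskell
-- Fibonacci numbers as a lazy infinite list: 0, 1, 1, 2, 3, 5, ...
fibs :: [Integer]
fibs = 0 : 1 : zipWith (+) fibs (tail fibs)

-- Number of calls the naive fib makes for each input n, following
-- C(0) = C(1) = 1 and C(n) = 1 + C(n - 1) + C(n - 2).
-- Building it as a lazy list computes each value only once.
callCounts :: [Integer]
callCounts = 1 : 1 : zipWith (\a b -> a + b + 1) callCounts (tail callCounts)

calls :: Int -> Integer
calls n = callCounts !! n
```

Here calls 20 is 21891, matching the count reported by countFib.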