mitchell vitez

Magic Square Generation

This is an exploration of backtracking algorithms and checking possible solutions in a search space. Backtracking isn’t the fastest way to make magic squares (and if you want sides bigger than 3, you’ll pretty much have to use a better algorithm) but it is fun to play with. We’ll be writing this in Rust, so it’ll be fast enough for our purposes anyways.

Full code from this post is here.

First, we’ll set the square side length at 3, and calculate the magic sum that each row, column, and diagonal will add up to. In this case, it’s 15. Let’s do all this as constant functions so it’s taken care of for us at compile time.

const SIZE: usize = 3;
const SUM: u32 = magic_constant(SIZE as u32);

const fn magic_constant(n: u32) -> u32 {
    return n * (n * n + 1) / 2;
}
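The formula comes from the fact that an n by n square holds the numbers 1 through n², which sum to n²(n² + 1)/2, and the n rows split that total evenly. Here is a minimal sketch of checking a couple of values at compile time. The `const _` assertion lines and the `main` loop are additions for illustration, not part of the original program.

```rust
const fn magic_constant(n: u32) -> u32 {
    // The numbers 1..=n*n sum to n*n * (n*n + 1) / 2; each of the n rows
    // gets an equal share, which leaves n * (n*n + 1) / 2.
    n * (n * n + 1) / 2
}

// Compile-time checks (added for illustration): the build fails if either
// value is wrong.
const _: () = assert!(magic_constant(3) == 15);
const _: () = assert!(magic_constant(4) == 34);

fn main() {
    for n in 3..=5 {
        println!("n = {}: magic constant = {}", n, magic_constant(n));
    }
}
```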

We’ll also need a \(\mathrm{SIZE} \times \mathrm{SIZE}\) grid that we can fill in with numbers. Let’s create an initial grid of u32, filled with zeroes, as a const fn.

const fn initial() -> [[u32; SIZE]; SIZE] {
    return [[0; SIZE]; SIZE];
}

It’d be nice to have a way to print out the grid, both for debugging and for investigating our solutions later. Let’s define a display function that prints one row per line, with a blank line after each grid.

fn display(grid: &[[u32; SIZE]; SIZE]) {
    for row in grid.iter() {
        for col in row.iter() {
            print!("{} ", col);
        }
        println!();
    }
    println!();
}

Backtracking algorithms have some check to see if a solution for the problem has been found. In this case, we’ll need to check that the sums all add to the magic sum.

fn check(grid: &[[u32; SIZE]; SIZE]) -> bool {

Our first check will see if all cells in the grid have been filled out. Because we’ll be filling from the top left to the bottom right, all we have to do is check if the bottom rightmost cell contains a zero.

    if grid[SIZE-1][SIZE-1] == 0 {
        return false;
    }

We can check both diagonals in a single loop by keeping two running sums and increasing one column index while decreasing the other.

    let mut sum1 = 0;
    let mut sum2 = 0;
    for i in 0..SIZE {
        sum1 += grid[i][i];
        sum2 += grid[i][SIZE-i-1];
    }
    if sum1 != SUM || sum2 != SUM {
        return false;
    }

Checking that each row has the same sum involves iterating over each row, summing them up, and checking.

    for row in grid.iter() {
        let mut total = 0;
        for elt in row.iter() {
            total += elt;
        }
        if total != SUM {
            return false;
        }
    }

Finally, we can do the same thing for columns by enumerating their indices and summing up.

    // One pass per column is enough; there is no need to repeat it for every row.
    for col_idx in 0..SIZE {
        let mut sum = 0;
        for row in grid.iter() {
            sum += row[col_idx];
        }
        if sum != SUM {
            return false;
        }
    }

    return true;
}
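For comparison, the same checks can be written with iterator adapters. This is only a sketch of an alternative style: the `Grid` alias and the name `is_magic` are made up for this example, and unlike the version above it tests every cell for zero rather than just the last one.

```rust
const SIZE: usize = 3;
const SUM: u32 = 15; // magic constant for a 3 by 3 square

type Grid = [[u32; SIZE]; SIZE];

// Hypothetical iterator-based variant of `check`.
fn is_magic(grid: &Grid) -> bool {
    // A zero anywhere means the grid has not been completely filled in.
    let filled = grid.iter().flatten().all(|&v| v != 0);
    let rows = grid.iter().all(|row| row.iter().sum::<u32>() == SUM);
    let cols = (0..SIZE).all(|c| grid.iter().map(|row| row[c]).sum::<u32>() == SUM);
    // Main diagonal runs top left to bottom right; the other runs top right
    // to bottom left.
    let diag = (0..SIZE).map(|i| grid[i][i]).sum::<u32>() == SUM;
    let anti = (0..SIZE).map(|i| grid[i][SIZE - 1 - i]).sum::<u32>() == SUM;
    filled && rows && cols && diag && anti
}

fn main() {
    let magic: Grid = [[2, 7, 6], [9, 5, 1], [4, 3, 8]];
    let not_magic: Grid = [[1, 2, 3], [4, 5, 6], [7, 8, 9]];
    println!("{}", is_magic(&magic));
    println!("{}", is_magic(&not_magic));
}
```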

I actually also tried this unrolled version, but it didn’t seem to get results any faster. Maybe the compiler is optimizing away all my for loops? Also, the unrolled version isn’t as generic (it wouldn’t automatically extend to a 4 by 4 grid), but that doesn’t really matter in this case.

fn check(grid: &[[u32; SIZE]; SIZE]) -> bool {
    return
        grid[2][2] + grid[1][1] + grid[0][0] == SUM &&
        grid[0][2] + grid[1][1] + grid[2][0] == SUM &&

        grid[0][0] + grid[0][1] + grid[0][2] == SUM &&
        grid[1][0] + grid[1][1] + grid[1][2] == SUM &&
        grid[2][0] + grid[2][1] + grid[2][2] == SUM &&

        grid[0][0] + grid[1][0] + grid[2][0] == SUM &&
        grid[0][1] + grid[1][1] + grid[2][1] == SUM &&
        grid[0][2] + grid[1][2] + grid[2][2] == SUM;
}

Let’s also define a quick helper function that’ll tell us whether some value already exists in the grid.

fn in_grid(val: u32, grid: &[[u32; SIZE]; SIZE]) -> bool {
    for elt in grid.iter().flat_map(|r| r.iter()) {
        if val == *elt {
            return true
        }
    }
    return false
}
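The same lookup can probably be expressed more compactly with `Iterator::any`. A minimal sketch, using the made-up name `in_grid_any` so it doesn’t clash with the helper above:

```rust
const SIZE: usize = 3;

// Hypothetical shorter form of `in_grid` using iterator adapters.
fn in_grid_any(val: u32, grid: &[[u32; SIZE]; SIZE]) -> bool {
    // Flatten the rows into one stream of cells and stop at the first match.
    grid.iter().flatten().any(|&elt| elt == val)
}

fn main() {
    let grid = [[2, 7, 0], [0, 0, 0], [0, 0, 0]];
    println!("{}", in_grid_any(7, &grid));
    println!("{}", in_grid_any(5, &grid));
}
```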

To solve, we find the first empty cell (scanning left to right, top to bottom) and try each number from 1 to \(\mathrm{SIZE}^2\) in it, skipping any number that’s already in the grid. After placing a number, we check whether the grid is a magic square. If it is, we print it out. If not, we recurse to fill in the next empty cell. Backtracking here is simply reassigning the cell’s value to zero once every candidate has been tried.

fn solve(grid: &mut [[u32; SIZE]; SIZE]) {
    // Find the first empty cell, scanning left to right, top to bottom.
    for i in 0..SIZE.pow(2) {
        let row = i / SIZE;
        let col = i % SIZE;

        if grid[row][col] == 0 {
            // Try every value that isn't already placed somewhere in the grid.
            for value in 1..=SIZE.pow(2) as u32 {
                if !in_grid(value, grid) {
                    grid[row][col] = value;
                    if check(grid) {
                        display(grid);
                    } else {
                        solve(grid);
                    }
                }
            }
            // Backtrack: clear this cell before handing control back.
            grid[row][col] = 0;
            return;
        }
    }
}

Finally, main is simply solving the initial grid.

fn main() {
    solve(&mut initial());
}
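If you’d rather count the solutions than read them off the terminal, one option is to collect them into a Vec. The sketch below is a self-contained variant under that assumption. The names `Grid`, `solve_into`, and `all_squares` are invented for this example, SUM is hardcoded to 15, and the helpers are condensed with iterators, so treat it as an illustration rather than the post’s actual code.

```rust
const SIZE: usize = 3;
const SUM: u32 = 15; // magic constant for a 3 by 3 square

type Grid = [[u32; SIZE]; SIZE];

fn in_grid(val: u32, grid: &Grid) -> bool {
    grid.iter().flatten().any(|&elt| elt == val)
}

fn check(grid: &Grid) -> bool {
    // Cells are filled in order, so a non-zero last cell means a full grid.
    if grid[SIZE - 1][SIZE - 1] == 0 {
        return false;
    }
    let rows = grid.iter().all(|row| row.iter().sum::<u32>() == SUM);
    let cols = (0..SIZE).all(|c| grid.iter().map(|row| row[c]).sum::<u32>() == SUM);
    let diag = (0..SIZE).map(|i| grid[i][i]).sum::<u32>() == SUM;
    let anti = (0..SIZE).map(|i| grid[i][SIZE - 1 - i]).sum::<u32>() == SUM;
    rows && cols && diag && anti
}

// Hypothetical variant of `solve` that records solutions instead of printing.
fn solve_into(grid: &mut Grid, found: &mut Vec<Grid>) {
    for i in 0..SIZE * SIZE {
        let (row, col) = (i / SIZE, i % SIZE);
        if grid[row][col] != 0 {
            continue;
        }
        for value in 1..=(SIZE * SIZE) as u32 {
            if in_grid(value, grid) {
                continue;
            }
            grid[row][col] = value;
            if check(grid) {
                found.push(*grid); // arrays of u32 are Copy, so this stores a snapshot
            } else {
                solve_into(grid, found);
            }
        }
        // Backtrack: clear the cell once every candidate has been tried.
        grid[row][col] = 0;
        return;
    }
}

fn all_squares() -> Vec<Grid> {
    let mut found = Vec::new();
    solve_into(&mut [[0; SIZE]; SIZE], &mut found);
    found
}

fn main() {
    let squares = all_squares();
    println!("found {} magic squares", squares.len());
}
```

Returning the grids also makes it easy to write a test that asserts on the count.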

We get back all eight 3 by 3 magic squares! This runs pretty fast for me with simple compiler options (just rustc -O magic_square.rs), taking about 0.061s.

2 7 6
9 5 1
4 3 8

2 9 4
7 5 3
6 1 8

4 3 8
9 5 1
2 7 6

4 9 2
3 5 7
8 1 6

6 1 8
7 5 3
2 9 4

6 7 2
1 5 9
8 3 4

8 1 6
3 5 7
4 9 2

8 3 4
1 5 9
6 7 2
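
These eight grids are all rotations and mirror images of a single square, which is why the 5 always lands in the middle. Here’s a rough sketch that rebuilds the list from one of them. The helper names `rotate`, `mirror`, and `symmetries` are made up for this example and aren’t part of the solver.

```rust
const SIZE: usize = 3;

type Grid = [[u32; SIZE]; SIZE];

// Rotate a quarter turn clockwise: the left column, read bottom to top,
// becomes the top row.
fn rotate(g: &Grid) -> Grid {
    let mut out = [[0; SIZE]; SIZE];
    for r in 0..SIZE {
        for c in 0..SIZE {
            out[r][c] = g[SIZE - 1 - c][r];
        }
    }
    out
}

// Mirror left to right by reversing each row.
fn mirror(g: &Grid) -> Grid {
    let mut out = *g;
    for row in out.iter_mut() {
        row.reverse();
    }
    out
}

// Four rotations, each with its mirror image, sorted so the order matches
// the solver's output.
fn symmetries(g: &Grid) -> Vec<Grid> {
    let mut out = Vec::new();
    let mut current = *g;
    for _ in 0..4 {
        out.push(current);
        out.push(mirror(&current));
        current = rotate(&current);
    }
    out.sort();
    out
}

fn main() {
    let start: Grid = [[4, 9, 2], [3, 5, 7], [8, 1, 6]];
    for g in symmetries(&start) {
        for row in g.iter() {
            println!("{} {} {}", row[0], row[1], row[2]);
        }
        println!();
    }
}
```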