
LSTM XOR

Full code from this post can be found in this repo

A classic example of a function that doesn’t respond well to linear machine learning algorithms is XOR. Because the meaning of each bit in a sequence depends on other bits, there’s no assignment of weights that we can give to each bit that captures the XOR function. That is to say, there’s no way to come up with weights \(w_1\) and \(w_2\) satisfying these equations, which make up XOR’s truth table:

\[ \begin{bmatrix} w_1 \\ w_2 \end{bmatrix}^\intercal \begin{bmatrix} 0 \\ 0 \end{bmatrix} = 0 , \begin{bmatrix} w_1 \\ w_2 \end{bmatrix}^\intercal \begin{bmatrix} 0 \\ 1 \end{bmatrix} = 1 , \begin{bmatrix} w_1 \\ w_2 \end{bmatrix}^\intercal \begin{bmatrix} 1 \\ 0 \end{bmatrix} = 1 , \begin{bmatrix} w_1 \\ w_2 \end{bmatrix}^\intercal \begin{bmatrix} 1 \\ 1 \end{bmatrix} = 0 \]
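To see this concretely, we can ask for the best possible linear fit and watch it miss. Here’s a quick sketch using NumPy’s least squares (with no bias term, matching the equations above); the best weights come out to \(w_1 = w_2 = 1/3\), and the resulting predictions land nowhere near the truth table:

import numpy as np

# all four XOR inputs and their outputs
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([0, 1, 1, 0])

# the best weights in the least-squares sense
w, *_ = np.linalg.lstsq(X, y, rcond=None)
print('weights:', w)          # [0.333... 0.333...]
print('predictions:', X @ w)  # [0, 0.333..., 0.333..., 0.666...], far from [0, 1, 1, 0]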

If we want to learn XOR, we have to use some nonlinearity to capture its full behavior. One way to do this is with recurrent neural networks, specifically long short-term memory networks (LSTMs). The structure of an LSTM allows it to remember disconnected pieces of an input. Think of it as reading in each bit of our XORed sequence one at a time. The LSTM should be able to remember the bits it’s seen, helping us keep track of the overall parity of the sequence. Luckily, this strategy also scales nicely, allowing us to perform the XOR task on fairly long sequences of bits.
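The state the network needs to carry turns out to be tiny: just the running parity of the bits seen so far. As a point of reference, here’s the whole computation the LSTM has to learn, written as a plain Python loop (xor_parity is just an illustrative name):

def xor_parity(bits):
    parity = 0
    for bit in bits:
        parity ^= bit  # every 1 flips the running parity
    return parity

print(xor_parity([1, 0, 1, 1]))  # 1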

In fact, we can write all the code needed to show off how this works in the span of about thirty lines. We’ll use Python, TensorFlow, and a bit of NumPy to make this happen. First, let’s import a bunch of things we’ll need to use later.

from tensorflow.keras import optimizers
from tensorflow.keras.layers import Dense, Input, LSTM
from tensorflow.keras.models import Sequential
import numpy as np
import random

So far, so good. Just a bunch of machine learny stuff. The most important parts to note are that we’re using a sequential model with some mixture of LSTM and Dense layers. Let’s set up some constants. We’ll operate on sequences of bits with length 50, and let’s create 100,000 of them for our training set.

SEQ_LEN = 50
COUNT = 100000

Our classifier will be fed both a sequence and its inverse, and will decide which one has an odd number of 1s. For example, if we see the sequence [1 0 1 1] and its inverse, [0 1 0 0], we’ll want the model to predict that the first sequence is the one that has an XOR of 1.

Let’s create a simple lambda bin_pair to help us create inverses of any sequence, by pairing off each bit with its opposite. Our training set consists of these pairs of random bits, in a list of length SEQ_LEN, and we’ll want COUNT training examples overall.

To calculate our target set, we’ll create binary pairs based on the cumulative sum of a sequence, modulo 2. For example, if we have a sequence [1 0 1 1 0 1] the cumulative sum is [1 1 2 3 3 4] and that cumulative sum modulo 2 is [1 1 0 1 1 0]. This sum captures the “current” parity of the sequence as we read it in from left to right. We mostly need the binary pairs so that our target set has the same dimensions as the training set.

bin_pair = lambda x: [x, not(x)]
training = np.array([[bin_pair(random.choice([0, 1])) for _ in range(SEQ_LEN)] for _ in range(COUNT)])
target = np.array([[bin_pair(x) for x in np.cumsum(example[:,0]) % 2] for example in training])
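As a quick sanity check, we can reproduce the cumulative-sum example from the prose above:

seq = np.array([1, 0, 1, 1, 0, 1])
print(np.cumsum(seq))      # [1 1 2 3 3 4]
print(np.cumsum(seq) % 2)  # [1 1 0 1 1 0], the running parity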

Let’s do a quick check to make sure the dimensions of the two datasets match up. They should both be \(100000 \times 50 \times 2\), since we have 100,000 examples of length 50, and each bit is paired off with its inverse.

print('shape check:', training.shape, '=', target.shape)

Now it’s time to build the model! Sequential means that our network will run each of its layers as a series of steps. We start by taking in the input. Notice that we drop the 100000 from the input shape, because we want the dimensions of each individual example here. The next layer is our single-unit LSTM, which reads in sequences for us and acts as the network’s “memory”. Finally, we use a dense layer with 2 units (one for each parity) and a softmax activation, which turns the network’s outputs into probabilities describing the parity of the sequence.

model = Sequential()
model.add(Input(shape=(SEQ_LEN, 2), dtype='float32'))
model.add(LSTM(1, return_sequences=True))
model.add(Dense(2, activation='softmax'))
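Because we set return_sequences=True, the model emits a prediction at every timestep rather than only at the end, so its output shape should match our target set exactly. A quick way to confirm:

print(model.output_shape)  # (None, 50, 2): one two-way softmax per timestep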

Using a binary crossentropy loss fits the decision we’re making: the XOR’s parity is either 0 or it’s 1. An Adam optimizer is fairly typical and helps adjust the learning rate so that our model can hopefully learn much faster without skipping over too many possible minima. We’ll track how well our model is doing by looking at its accuracy.

Model fitting is the step that takes the longest time. 10 epochs should be enough for our model to gain reasonable predictive power without spending too long training. A batch size of 128 is fairly standard, and we tend to like powers of two there because GPUs handle them better.

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(training, target, epochs=10, batch_size=128)
model.summary()

Now that our model has been trained, we can see how it works on an arbitrary sequence of bits. Let’s select one at random, and print out what our model predicts, how confident it is in that prediction, and what the actual answer ended up being.

predictions = model.predict(training)
i = random.randrange(COUNT)  # a valid index from 0 to COUNT - 1
chance = predictions[i,-1,0]
print('randomly selected sequence:', training[i,:,0])
print('prediction:', int(chance > 0.5))
print('confidence: {:0.2f}%'.format((chance if chance > 0.5 else 1 - chance) * 100))
print('actual:', np.sum(training[i,:,0]) % 2)
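One caveat: we’re predicting on the training set here. For a fairer read on the model, one option is to generate a fresh batch of sequences it has never seen and measure accuracy there. A minimal sketch, reusing bin_pair from above:

test = np.array([[bin_pair(random.choice([0, 1])) for _ in range(SEQ_LEN)] for _ in range(1000)])
test_preds = model.predict(test)
guesses = (test_preds[:, -1, 0] > 0.5).astype(int)  # predicted parity of each sequence
actual = test[:, :, 0].sum(axis=1) % 2              # true parity of each sequence
print('held-out accuracy:', (guesses == actual).mean())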

There’s definitely a lot to take in here, but hopefully this amount of code is small enough that it seems reasonable to just dive right in, start making tweaks, and see what happens. Machine learning is super easy to experiment with (all you really need is a computer of some kind), so it makes sense to play around and see how different ways of prodding the model produce different results.

This model took under 2 minutes to train on my laptop CPU, and reached confidence levels around 98% on sequences of length 50, with 100,000 examples in the training set.