Transformer models are foundational to AI systems. There are now countless explanations of “how Transformers work” in the sense of the architecture diagram at their heart.
However, this diagram does not provide any intuition into the computational model of this framework. As researchers become interested in how Transformers work, gaining intuition into their mechanisms becomes increasingly useful.
Thinking Like Transformers proposes a computational framework for Transformer-like calculations. The framework uses discrete computation to simulate Transformer computations. The resulting language, RASP, is a programming language where, ideally, every program can compile down to a specific Transformer (indeed, David Lindner and colleagues have recently released a compiler for a large subset of RASP!).
In this blog post, I reimplement a variant of RASP in Python (RASPy). The language is roughly compatible with the original version, but with some syntactic changes that I thought were fun. Additionally, Gail Weiss, the author of the original work, has provided a challenging set of puzzles that walk through the language and help build an understanding of how it works.
#!pip install git+https://github.com/srush/RASPy
Before jumping into the language itself, let’s look at an example of what coding with Transformers looks like. Here is some code that computes the flip, i.e. the reversal of an input sequence. The code uses two Transformer layers, applying attention and mathematical computations to achieve the result.
def flip():
    length = (key(1) == query(1)).value(1)
    flip = (key(length - indices - 1) == query(indices)).value(tokens)
    return flip
flip()
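To make the two layers concrete, here is a plain-Python sketch of what flip() computes, with no RASPy dependency. The helpers select and aggregate are hypothetical stand-ins written for this illustration, not RASPy functions: a selector is a nested list of booleans (rows are output positions, columns are input positions), and aggregation sums the selected values, matching RASPy's sum default.

```python
# Plain-Python sketch of flip(); select/aggregate are hypothetical helpers, not the RASPy API.

def select(keys, queries, predicate):
    # selector[q][k] is True when output position q attends to input position k.
    return [[predicate(k, q) for k in keys] for q in queries]

def aggregate(selector, values):
    # Sum-aggregate: each output position adds up the values it attends to.
    out = []
    for row in selector:
        picked = [v for v, on in zip(values, row) if on]
        total = picked[0] if picked else 0
        for v in picked[1:]:
            total = total + v
        out.append(total)
    return out

def flip(tokens):
    n = len(tokens)
    indices = list(range(n))
    # Layer 1: a "select all" pattern that sums a 1 from every position gives the length.
    length = aggregate(select([1] * n, [1] * n, lambda k, q: k == q), [1] * n)
    # Layer 2: output position i attends to input position j where length - j - 1 == i.
    keys = [length[j] - j - 1 for j in indices]
    return aggregate(select(keys, indices, lambda k, q: k == q), list(tokens))

print(flip("hello"))  # ['o', 'l', 'l', 'e', 'h']
```

The first aggregate plays the role of the length layer; the second routes each token to its mirrored position.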
Our goal is to define a computational formalism that mimics the expressivity of Transformers. We will go through this process by analogy, describing each language construct next to the aspect of the Transformer it represents. (See the full paper for the formal language specification.)
The core unit of the language is a sequence operation that transforms a sequence to another sequence of the same length. I will refer to these throughout as transforms.
In a Transformer, the base layer is the input fed to the model. This input usually contains the raw tokens as well as positional information.
In code, the symbol tokens represents the simplest transform. It returns the tokens passed to the model. The default input is the sequence “hello”.
If we want to change the input to the transform, we use the input method to pass in an alternative.
tokens.input([5, 2, 4, 5, 2, 2])
As with Transformers, we cannot access the positions of these sequences directly. However, to mimic position embeddings, we have access to a sequence of indices.
sop = indices
sop.input("goodbye")
After the input layer, we reach the feed-forward network. In a Transformer, this stage can apply mathematical operations to each element of the sequence independently.
In code, we represent this stage by computation on transforms. Mathematical operations are overloaded to represent independent computation on each element of the sequence.
tokens == "l"
The result is a new transform. Once constructed it can be applied to new input.
model = tokens * 2 - 1
model.input([1, 2, 3, 5, 2])
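One way to picture this lazy behavior is a tiny class that overloads Python operators to build new transforms instead of computing values. The Transform class below is a hypothetical sketch of the idea, not RASPy's actual implementation; it supports only a few operators and represents a transform as a function from an input list to an output list of the same length.

```python
# Hypothetical mini-implementation of lazily composed, elementwise transforms.
# A sketch of the idea only; RASPy's real classes differ.

class Transform:
    def __init__(self, fn):
        self.fn = fn  # maps an input list to an output list of the same length

    def input(self, xs):
        # Only here is anything actually computed.
        return self.fn(list(xs))

    def _lift(self, other, op):
        # Combine with another transform or a scalar, position by position.
        if isinstance(other, Transform):
            other_fn = other.fn
        else:
            other_fn = lambda xs: [other] * len(xs)  # broadcast the scalar
        return Transform(lambda xs: [op(a, b) for a, b in zip(self.fn(xs), other_fn(xs))])

    def __add__(self, other): return self._lift(other, lambda a, b: a + b)
    def __sub__(self, other): return self._lift(other, lambda a, b: a - b)
    def __mul__(self, other): return self._lift(other, lambda a, b: a * b)
    def __eq__(self, other): return self._lift(other, lambda a, b: a == b)

tokens = Transform(lambda xs: xs)
indices = Transform(lambda xs: list(range(len(xs))))

model = tokens * 2 - 1                 # builds a new transform; nothing runs yet
print(model.input([1, 2, 3, 5, 2]))    # [1, 3, 5, 9, 3]
print((tokens == "l").input("hello"))  # [False, False, True, True, False]
```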
Operations can combine multiple transforms, for example, functions of both tokens and indices. The analogy here is that Transformer activations can keep track of multiple pieces of information simultaneously.
model = tokens - 5 + indices
model.input([1, 2, 3, 5, 2])
(tokens == "l") | (indices == 1)
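In plain Python, the same two computations are ordinary list comprehensions over (index, token) pairs. This sketch only spells out the elementwise arithmetic on the example inputs; each output position depends solely on its own token and its own index, which is what makes it a feed-forward-style step.

```python
# Elementwise combination of tokens and indices, written out by hand.
xs = [1, 2, 3, 5, 2]
combined = [t - 5 + i for i, t in enumerate(xs)]  # tokens - 5 + indices
print(combined)  # [-4, -2, 0, 3, 1]

word = "hello"
mask = [(t == "l") or (i == 1) for i, t in enumerate(word)]  # (tokens == "l") | (indices == 1)
print(mask)  # [False, True, True, True, False]
```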
We provide a few helper functions to make it easier to write transforms. For example, where provides an “if”-statement-like construct.
where((tokens == "h") | (tokens == "l"), tokens, "q")
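Semantically, where chooses between two sequences position by position. Here is a minimal plain-Python sketch of that behavior; the where function below is a hypothetical re-implementation, and broadcasting a non-list argument (such as "q") to the sequence length is an assumption about how scalars are handled.

```python
def where(cond, then_vals, else_vals):
    # Assumption: anything that is not a list (e.g. the scalar "q") is broadcast to the full length.
    n = len(cond)
    def as_seq(v):
        return v if isinstance(v, list) else [v] * n
    # Elementwise "if": take from then_vals where cond holds, otherwise from else_vals.
    return [a if c else b for c, a, b in zip(cond, as_seq(then_vals), as_seq(else_vals))]

word = list("hello")
cond = [(t == "h") or (t == "l") for t in word]
print("".join(where(cond, word, "q")))  # hqllq
```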
map lets us define our own operators, for instance a string-to-int transform. (Users should be careful to only use operations here that could be computed with a simple neural network.)
atoi = tokens.map(lambda x: ord(x) - ord('0'))
atoi.input("31234")
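Because the vocabulary is finite, a per-position function like atoi amounts to a lookup table, which is one way to see why a small feed-forward network could plausibly compute it. The sketch below makes the table explicit; map_elementwise and ATOI_TABLE are hypothetical names, and the table assumes inputs are the characters 0 through 9.

```python
# Assumption: the vocabulary is the ten digit characters, so the function is a finite table.
DIGITS = "0123456789"
ATOI_TABLE = {ch: ord(ch) - ord("0") for ch in DIGITS}

def map_elementwise(seq, table):
    # Apply the same per-position function to every element independently.
    return [table[x] for x in seq]

print(map_elementwise("31234", ATOI_TABLE))  # [3, 1, 2, 3, 4]
```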
When chaining these transforms, it is often easier to work with functions. For example, the following applies where, then atoi, and then adds 2.
def atoi(seq=tokens):
    return seq.map(lambda x: ord(x) - ord('0'))
op = (atoi(where(tokens == "-", "0", tokens)) + 2)
op.input("02-13")
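Written out in plain Python, the chained pipeline is three elementwise passes: replace dashes, convert digits, add 2. The helpers here are hypothetical re-implementations for illustration only, not the RASPy definitions.

```python
def where(cond, then_vals, else_vals):
    # Assumption: non-list arguments are broadcast to the sequence length.
    n = len(cond)
    def as_seq(v):
        return v if isinstance(v, list) else [v] * n
    return [a if c else b for c, a, b in zip(cond, as_seq(then_vals), as_seq(else_vals))]

def atoi(seq):
    # Assumes every element is a single digit character.
    return [ord(x) - ord("0") for x in seq]

s = list("02-13")
no_dash = where([c == "-" for c in s], "0", s)  # ['0', '2', '0', '1', '3']
op = [v + 2 for v in atoi(no_dash)]
print(op)  # [2, 4, 2, 3, 5]
```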
From here on, unless we use a different input sequence, we will assume that the input is ‘hello’ and omit the input display in the illustrations.
Things get more interesting when we start to apply attention. This allows routing of information between the different elements of the sequence.
We begin by defining notation for the keys and queries of the model. Keys and queries are effectively transforms that we will broadcast and compare to each other to create selectors, our parallel to attention patterns. We create them directly from transforms. For example, if we want to define a key, we call key on a transform.
Similarly, to define a query, we call query. (Queries are presented as columns to reflect their relation to the selectors we will create from them.)
Scalars can be used as keys or queries. They broadcast out to the length of the underlying sequence.
By applying a comparison operation between a key and a query, we create a selector, our parallel to an attention matrix, though this one is unweighted.
A selector is a binary matrix indicating which input position (column) each output position (row) will attend to in an eventual attention computation. In the comparison creating it, the key values describe the input (column) positions, and the query values describe the output (row) positions.
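As a sketch of the data structure, a selector can be modeled as a nested list of booleans, built by comparing every key value against every query value. The select and show helpers here are hypothetical; following the convention above, rows correspond to queries (output positions) and columns to keys (input positions).

```python
def select(keys, queries, predicate):
    # selector[q][k] is True when output position q may attend to input position k.
    return [[predicate(k, q) for k in keys] for q in queries]

def show(selector):
    # One printed row per output position; "1" marks an attended input position.
    for row in selector:
        print("".join("1" if cell else "." for cell in row))

word = "hello"
eq = select(word, word, lambda k, q: k == q)  # like key(tokens) == query(tokens)
show(eq)
# 1....
# .1...
# ..11.
# ..11.
# ....1

# Scalar keys and queries broadcast to every position, giving a "select all" pattern.
everything = select([1] * len(word), [1] * len(word), lambda k, q: k == q)
```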
eq = (key(tokens) == query(tokens))
eq
offset = (key(indices) == query(indices - 1))
offset
before = key(indices) < query(indices)
before
after = key(indices) > query(indices)
after
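Using a hypothetical select helper like the one sketched for selectors, the three index-based selectors can be printed for a length-5 input. Because the query for offset is indices - 1, each row attends to the column just before it, and row 0 attends to nothing.

```python
def select(keys, queries, predicate):
    # selector[q][k]: row q is an output position, column k an input position.
    return [[predicate(k, q) for k in keys] for q in queries]

idx = list(range(5))
offset = select(idx, [i - 1 for i in idx], lambda k, q: k == q)  # key(indices) == query(indices - 1)
before = select(idx, idx, lambda k, q: k < q)  # strictly earlier positions
after = select(idx, idx, lambda k, q: k > q)   # strictly later positions

for name, sel in [("offset", offset), ("before", before), ("after", after)]:
    print(name)
    for row in sel:
        print("".join("1" if cell else "." for cell in row))
```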
Selectors can be merged using boolean operations. For example, this selector focuses each output position on 1) earlier positions that 2) contain the same original input token as its own. We show this by including both pairs of keys and queries in the matrix.
before & eq
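Merging two selectors is an elementwise boolean operation on their matrices. In this plain-Python sketch (combine is a hypothetical helper), the only true entry for “hello” is the second l attending to the first.

```python
def select(keys, queries, predicate):
    return [[predicate(k, q) for k in keys] for q in queries]

def combine(a, b, op):
    # Apply a boolean operator cell by cell to two selectors of the same shape.
    return [[op(x, y) for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]

word = "hello"
idx = list(range(len(word)))
before = select(idx, idx, lambda k, q: k < q)
eq = select(word, word, lambda k, q: k == q)
before_and_eq = combine(before, eq, lambda x, y: x and y)

# List the (output, input) pairs that remain selected.
print([(q, k) for q, row in enumerate(before_and_eq) for k, cell in enumerate(row) if cell])  # [(3, 2)]
```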
Given an attention selector, we can provide a value sequence to aggregate. We represent aggregation by summing: each output position adds up the values at the input positions that its row of the selector marks as true.
(Note: the original paper uses mean aggregation and shows a clever construction in which mean aggregation can represent a sum calculation. RASPy uses sum by default for simplicity and to avoid fractions. In practice, this means that RASPy may underestimate the number of layers needed to convert to a mean-based model by a factor of 2.)
Attention aggregation gives us the ability to compute functions like histograms.
(key(tokens) == query(tokens)).value(1)
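Aggregation can be sketched the same way: each output row sums the value sequence at the positions its selector row marks. With a value of 1 everywhere and the token-equality selector, this yields a histogram. The helpers are hypothetical. The mean variant is included only to illustrate the note on mean aggregation: averaging a constant 1 returns 1 everywhere, so a plain mean does not count on its own.

```python
def select(keys, queries, predicate):
    return [[predicate(k, q) for k in keys] for q in queries]

def aggregate_sum(selector, values):
    # Sum the values at every attended position (RASPy's default, per the note above).
    return [sum(v for v, on in zip(values, row) if on) for row in selector]

def aggregate_mean(selector, values):
    # Mean over attended positions; returning 0 for an empty row is an assumption of this sketch.
    out = []
    for row in selector:
        picked = [v for v, on in zip(values, row) if on]
        out.append(sum(picked) / len(picked) if picked else 0)
    return out

word = "hello"
eq = select(word, word, lambda k, q: k == q)
ones = [1] * len(word)
print(aggregate_sum(eq, ones))   # [1, 1, 2, 2, 1]
print(aggregate_mean(eq, ones))  # [1.0, 1.0, 1.0, 1.0, 1.0]
```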
Visually, we follow the architecture diagram: queries are on the left, keys at the top, values at the bottom, and the output on the right.
Some attention operations may not even use the input tokens. For instance, to compute the length of a sequence, we create a “select all” attention selector and then sum a value of 1 from each position.
length = (key(1) == query(1)).value(1)
length = length.name("length")
length
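The length computation in the same plain-Python terms: scalar keys and queries make every comparison true, and summing a 1 from each position gives the sequence length at every position. As before, select and aggregate are hypothetical helpers.

```python
def select(keys, queries, predicate):
    return [[predicate(k, q) for k in keys] for q in queries]

def aggregate(selector, values):
    # Sum the values at every attended position.
    return [sum(v for v, on in zip(values, row) if on) for row in selector]

def length(seq):
    n = len(seq)
    select_all = select([1] * n, [1] * n, lambda k, q: k == q)  # key(1) == query(1)
    return aggregate(select_all, [1] * n)                        # .value(1)

print(length("hello"))  # [5, 5, 5, 5, 5]
```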
Here’s a more complex example, shown step-by-step. (This is the kind of thing they ask in interviews!)
Say we want to compute the sum of neighboring values in a sequence, along a sliding window. First we apply the forward cutoff, attending only to positions that are not too far in the past.
WINDOW = 3
s1 = (key(indices) >= query(indices - WINDOW + 1))
s1
Then the backward cutoff, attending only to positions up to and including our own.
s2 = (key(indices) <= query(indices))
s2
sel = s1 & s2
sel
And finally aggregate.
sum2 = sel.value(tokens)
sum2.input([1, 3, 2, 2, 2])
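For reference, the same sliding-window sum in plain Python, using hypothetical select, combine, and aggregate helpers. Each output row q keeps only the key positions k with q - WINDOW + 1 <= k <= q, so on [1, 3, 2, 2, 2] every output is the sum of up to three trailing values.

```python
def select(keys, queries, predicate):
    return [[predicate(k, q) for k in keys] for q in queries]

def combine(a, b, op):
    return [[op(x, y) for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]

def aggregate(selector, values):
    return [sum(v for v, on in zip(values, row) if on) for row in selector]

WINDOW = 3
xs = [1, 3, 2, 2, 2]
idx = list(range(len(xs)))
s1 = select(idx, [i - WINDOW + 1 for i in idx], lambda k, q: k >= q)  # forward cutoff
s2 = select(idx, idx, lambda k, q: k <= q)                            # backward cutoff
sel = combine(s1, s2, lambda x, y: x and y)
print(aggregate(sel, xs))  # [1, 4, 6, 7, 6]
```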
Attention can also compute a cumulative sum: each position attends to itself and every earlier position, then sums their values.
def cumsum(seq=tokens):
    x = (before | (key(indices) == query(indices))).value(seq)
    return x.name("cumsum")
cumsum().input([3, 1, -2, 3, 1])
The language supports building up more complex transforms. It keeps track of the layers by tracking the operations computed so far.
Here is a simple example that produces a 2-layer transform. The first layer computes length and the second the cumulative sum. The cumulative sum has to go into a second layer because it is applied to a transform that uses length, so it can only be computed after the computation of length is complete.
x = cumsum(length - indices)
x.input([3, 2, 3, 5])
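One way to picture how layers get counted is to tag every sequence with the number of attention steps it depends on: elementwise operations take the maximum over their inputs, and each aggregation adds one. The names Seq, elementwise, and attend below are hypothetical, and this is a simplified sketch rather than RASPy's actual bookkeeping; it does reproduce the 2-layer result for cumsum(length - indices).

```python
from dataclasses import dataclass

@dataclass
class Seq:
    vals: list
    depth: int  # number of attention layers needed to produce vals

def elementwise(fn, *seqs):
    # Feed-forward step: no new layer, so depth is the max over the inputs.
    vals = [fn(*xs) for xs in zip(*(s.vals for s in seqs))]
    return Seq(vals, max(s.depth for s in seqs))

def attend(selector_inputs, selector, values):
    # Attention step: it must run after everything it reads, so depth grows by one.
    vals = [sum(v for v, on in zip(values.vals, row) if on) for row in selector]
    depth = max([values.depth] + [s.depth for s in selector_inputs]) + 1
    return Seq(vals, depth)

def length(indices):
    n = len(indices.vals)
    select_all = [[True] * n for _ in range(n)]  # built from scalars only
    return attend([], select_all, Seq([1] * n, 0))

def cumsum(seq, indices):
    # Each position attends to itself and every earlier position.
    sel = [[k <= q for k in indices.vals] for q in indices.vals]
    return attend([indices], sel, seq)

xs = [3, 2, 3, 5]
indices = Seq(list(range(len(xs))), 0)
n_minus_i = elementwise(lambda a, b: a - b, length(indices), indices)  # length - indices
x = cumsum(n_minus_i, indices)
print(x.vals, x.depth)  # [4, 7, 9, 10] 2
```

Applying cumsum directly to the tokens, as in cumsum().input([3, 1, -2, 3, 1]), needs only one layer in this model, because nothing it reads depends on an earlier aggregation.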