Transformer models are foundational to AI systems. There are now countless explanations of “how Transformers work” in the sense of the architecture diagram at the heart of Transformers.
However, this diagram does not provide any intuition into the computational model of this framework. As researchers become interested in how Transformers work, gaining intuition into their mechanisms becomes increasingly useful.
Thinking like Transformers proposes a computational framework for Transformer-like calculations. The framework uses discrete computation to simulate Transformer computations. The resulting language, RASP, is a programming language in which every program compiles down to a specific Transformer.
In this blog post, I reimplement a variant of RASP in Python (RASPy). The language is roughly compatible with the original version, but with some syntactic changes that I thought were fun. With this language, the author of the work, Gail Weiss, provided a challenging set of puzzles to walk through and understand how it works.
#!pip install git+https://github.com/srush/RASPy
Before jumping into the language itself, let’s look at an example of what coding with Transformers looks like. Here is some code that computes the flip, i.e. reversing an input sequence. The code itself uses two Transformer layers to apply attention and mathematical computations to achieve the result.
def flip():
    length = (key(1) == query(1)).value(1)
    flip = (key(length - indices - 1) == query(indices)).value(tokens)
    return flip

flip()
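To make the two layers concrete without installing anything, here is a rough plain-Python sketch of what I take `flip` to compute. It is not RASPy and not the compiled Transformer; the function name `flip_sketch` is my own, and the loops simply stand in for attention.

```python
def flip_sketch(tokens):
    """Plain-Python stand-in for the two attention layers in flip()."""
    positions = range(len(tokens))
    # Layer 1: every query attends to every key and sums the value 1,
    # so each position learns the sequence length.
    length = [sum(1 for _ in positions) for _ in positions]
    # Layer 2: query q attends to the single key k with length - k - 1 == q
    # and copies that key's token.
    out = []
    for q in positions:
        picked = [tokens[k] for k in positions if length[k] - k - 1 == q]
        out.append(picked[0])
    return out


print("".join(flip_sketch("hello")))  # olleh
```

The second layer depends on the output of the first, which is why the program needs two layers rather than one.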
Our goal is to define a computational formalism that mimics the expressivity of Transformers. We will go through this process by analogy, describing each language construct next to the aspect of the Transformer it represents. (See the full paper for the formal language specification.)
The core unit of the language is a sequence operation that transforms a sequence to another sequence of the same length. I will refer to these throughout as transforms.
In a Transformer, the base layer is the input fed to the model. This input usually contains the raw tokens as well as positional information.
In code, the symbol tokens represents the simplest transform. It returns the tokens passed to the model. The default input is the sequence “hello”.
If we want to change the input to the transform, we use the input method to pass in an alternative.
tokens.input([5, 2, 4, 5, 2, 2])
As with Transformers, we cannot access the positions of these sequences directly. However, to mimic position embeddings, we have access to a sequence of indices.
sop = indices
sop.input("goodbye")
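One way to picture these two base transforms is as ordinary functions from an input sequence to a sequence of the same length. The sketch below is only an illustration in plain Python; `tokens_sketch` and `indices_sketch` are names I made up, not part of RASPy.

```python
def tokens_sketch(seq):
    """Return the input unchanged, one element per position."""
    return list(seq)


def indices_sketch(seq):
    """Return the position of each element, mimicking position embeddings."""
    return list(range(len(seq)))


print(tokens_sketch([5, 2, 4, 5, 2, 2]))  # [5, 2, 4, 5, 2, 2]
print(indices_sketch("goodbye"))          # [0, 1, 2, 3, 4, 5, 6]
```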
After the input layer, we reach the feed-forward network. In a Transformer, this stage can apply mathematical operations to each element of the sequence independently.
In code, we represent this stage by computation on transforms. Mathematical operations are overloaded to represent independent computation on each element of the sequence.
tokens == "l"
The result is a new transform. Once constructed it can be applied to new input.
model = tokens * 2 - 1
model.input([1, 2, 3, 5, 2])
Operations can combine multiple transforms, for example functions of both tokens and indices. The analogy here is that the Transformer activations can keep track of multiple pieces of information simultaneously.
model = tokens - 5 + indices
model.input([1, 2, 3, 5, 2])
(tokens == "l") | (indices == 1)
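If it helps, the element-wise stage can be pictured as a function applied position by position, with no information moving between positions. This is a plain-Python sketch, not RASPy; the helper `elementwise` is a hypothetical name of mine.

```python
def elementwise(fn, *seqs):
    """Apply fn independently at each position of equal-length sequences."""
    return [fn(*values) for values in zip(*seqs)]


tokens = [1, 2, 3, 5, 2]
indices = list(range(len(tokens)))

# Counterpart of: tokens * 2 - 1
print(elementwise(lambda t: t * 2 - 1, tokens))               # [1, 3, 5, 9, 3]
# Counterpart of: tokens - 5 + indices
print(elementwise(lambda t, i: t - 5 + i, tokens, indices))   # [-4, -2, 0, 3, 1]
```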
We provide a few helper functions to make it easier to write transforms. For example, where provides an “if”-statement-like construct.
where((tokens == "h") | (tokens == "l"), tokens, "q")
map lets us define our own operators, for instance a string to int transform. (Users should be careful to only use operations here that could be computed with a simple neural network).
atoi = tokens.map(lambda x: ord(x) - ord('0'))
atoi.input("31234")
When chaining these transforms, it is often easier to write them as functions. For example, the following applies where and then atoi and then adds 2.
def atoi(seq=tokens):
    return seq.map(lambda x: ord(x) - ord('0'))

op = (atoi(where(tokens == "-", "0", tokens)) + 2)
op.input("02-13")
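To check what that chain should produce, here is a small plain-Python sketch of the same steps. The helpers `broadcast` and `where_sketch` are my own stand-ins and only approximate how I understand where to behave, including scalars being repeated across the sequence.

```python
def broadcast(x, n):
    """Repeat a scalar n times; pass a list through unchanged."""
    return x if isinstance(x, list) else [x] * n


def where_sketch(cond, if_true, if_false):
    """Pick if_true where cond holds and if_false elsewhere, per position."""
    n = len(cond)
    yes, no = broadcast(if_true, n), broadcast(if_false, n)
    return [a if c else b for c, a, b in zip(cond, yes, no)]


tokens = list("02-13")
# Replace "-" with "0", convert each character to a digit, then add 2.
cleaned = where_sketch([t == "-" for t in tokens], "0", tokens)
op = [ord(c) - ord("0") + 2 for c in cleaned]
print(op)  # [2, 4, 2, 3, 5]
```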
Things get more interesting when we start to apply attention. This allows routing of information between the different elements of the sequence.
We begin by defining notation for the keys and queries of the model. Keys and queries can be created directly from the transforms defined above. For example, if we want to define a key we call key on a transform, as in key(tokens); queries are built the same way with query.
Scalars can be used as keys or queries. They broadcast out to the length of the underlying sequence.
By applying an operation between keys and queries we create a selector. This corresponds to a binary matrix indicating which keys each query is attending to. Unlike in Transformers, this attention matrix is unweighted.
eq = (key(tokens) == query(tokens))
eq
offset = (key(indices) == query(indices - 1))
offset
before = key(indices) < query(indices)
before
after = key(indices) > query(indices)
after
Selectors can be merged with boolean operations. For example, this selector attends only to tokens before it in time with the same value. We show this by including both pairs of keys and values in the matrix.
before & eq
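A selector can be pictured as a plain boolean matrix with one row per query and one column per key. The sketch below builds the combined selector for the default input “hello”; it is plain Python rather than RASPy, and `select` is a hypothetical helper of mine.

```python
def select(keys, queries, predicate):
    """Row q, column k is True when predicate(keys[k], queries[q]) holds."""
    return [[predicate(k, q) for k in keys] for q in queries]


tokens = list("hello")
indices = list(range(len(tokens)))

eq = select(tokens, tokens, lambda k, q: k == q)
before = select(indices, indices, lambda k, q: k < q)
# Element-wise "and" of the two matrices, like before & eq.
both = [[a and b for a, b in zip(r1, r2)] for r1, r2 in zip(before, eq)]

# Queries run down the rows, keys across the columns.
for row in both:
    print("".join("1" if cell else "." for cell in row))
```

For “hello” the only true cell is in the row of the second “l”, under the column of the first “l”, since that is the only earlier position holding the same token.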
Given an attention selector we can provide a value sequence to aggregate. We represent aggregation by summing up over the values that have a true value for their selector.
(Note: the original paper uses a mean aggregation and shows a clever construction where mean aggregation is able to represent a sum calculation. RASPy uses sum by default for simplicity and to avoid fractions. In practice this means that RASPy may underestimate the number of layers needed to convert to a mean-based model by a factor of 2.)
Attention aggregation gives us the ability to compute functions like histograms.
(key(tokens) == query(tokens)).value(1)
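Here is a minimal plain-Python sketch of sum aggregation applied to the histogram selector. The helper `aggregate` is my own name, and the code assumes the sum-based behavior described above rather than the paper's mean-based one.

```python
def aggregate(selector, values):
    """For each query row, sum the values at the keys that row selects."""
    return [sum(v for chosen, v in zip(row, values) if chosen)
            for row in selector]


tokens = list("hello")
# Each query attends to every key holding the same token.
selector = [[k == q for k in tokens] for q in tokens]
# Summing the constant 1 over the selected keys counts them.
print(aggregate(selector, [1] * len(tokens)))  # [1, 1, 2, 2, 1]
```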
Visually we follow the architecture diagram. Queries are to the left, Keys at the top, Values at the bottom, and the Output is to the right.
Some attention operations may not even use the input tokens. For instance, to compute the length of a sequence, we create a “select all” attention selector and then add the values.
length = (key(1) == query(1)).value(1)
length = length.name("length")
length
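Length is also a handy place to see the difference between sum and mean aggregation mentioned in the note above. The sketch below computes it both ways in plain Python. The mean-based version, which averages an indicator of position 0 and then inverts it element-wise, is my assumption about one way to do it and not necessarily the exact construction in the paper.

```python
from fractions import Fraction


def length_by_sum(seq):
    """Select all keys for every query and sum the constant 1."""
    n = len(seq)
    select_all = [[True] * n for _ in range(n)]
    return [sum(1 for chosen in row if chosen) for row in select_all]


def length_by_mean(seq):
    """Average an indicator of position 0 over all keys, then invert it."""
    n = len(seq)
    is_first = [1 if i == 0 else 0 for i in range(n)]
    # Attention step: the mean over all n keys is 1/n at every position.
    means = [Fraction(sum(is_first), n) for _ in range(n)]
    # Element-wise step: the reciprocal recovers n.
    return [int(1 / m) for m in means]


print(length_by_sum("hello"))   # [5, 5, 5, 5, 5]
print(length_by_mean("hello"))  # [5, 5, 5, 5, 5]
```

The mean-based version needs an extra element-wise step after the attention, which hints at why a mean-based model can need more machinery for the same program.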
Here’s a more complex example, shown step-by-step. (This is the kind of thing they ask in interviews!)
Say we want to compute the sum of neighboring values in a sequence. First we apply the forward cutoff.
WINDOW = 3
s1 = (key(indices) >= query(indices - WINDOW + 1))
s1
Then the backward cutoff.
s2 = (key(indices) <= query(indices))
s2
sel = s1 & s2
sel
And finally aggregate.
sum2 = sel.value(tokens)
sum2.input([1,3,2,2,2])
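As a sanity check on the three steps, here is the same windowed sum written out in plain Python. The function name `window_sum` is mine, and the loops are only a stand-in for the selector and the sum aggregation.

```python
WINDOW = 3


def window_sum(seq, window=WINDOW):
    """Sum each element with up to window - 1 elements before it."""
    positions = range(len(seq))
    out = []
    for q in positions:
        # Forward cutoff: keys no further back than q - window + 1.
        s1 = [k >= q - window + 1 for k in positions]
        # Backward cutoff: keys at or before the query.
        s2 = [k <= q for k in positions]
        # Aggregate: sum the values where both cutoffs hold.
        out.append(sum(v for a, b, v in zip(s1, s2, seq) if a and b))
    return out


print(window_sum([1, 3, 2, 2, 2]))  # [1, 4, 6, 7, 6]
```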
Here’s a similar example with a cumulative sum. We introduce here the ability to name a transform, which helps with debugging.
def cumsum(seq=tokens):
    x = (before | (key(indices) == query(indices))).value(seq)
    return x.name("cumsum")

cumsum().input([3, 1, -2, 3, 1])
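The selector here says "every key strictly before the query, or the query's own position". A plain-Python sketch of that rule, checked against `itertools.accumulate` from the standard library, looks like this; `cumsum_sketch` is a name I chose for illustration.

```python
from itertools import accumulate


def cumsum_sketch(seq):
    """Sum, for each query q, the values at keys k with k < q or k == q."""
    positions = range(len(seq))
    return [sum(seq[k] for k in positions if k < q or k == q)
            for q in positions]


data = [3, 1, -2, 3, 1]
print(cumsum_sketch(data))  # [3, 4, 2, 5, 6]
# The attention-style version agrees with the usual running sum.
assert cumsum_sketch(data) == list(accumulate(data))
```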
The language supports building up more complex transforms. It keeps track of the layers by tracking the operations computed so far.
Here is a simple example that produces a 2-layer transform. The first corresponds to computing length and the second the cumulative sum.
x = cumsum(length - indices)
x.input([3, 2, 3, 5])
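One way to think about the layer count is as a depth that grows by one at each attention step and stays put for element-wise operations. The sketch below is only my illustration of that bookkeeping; the `Node`, `elementwise`, and `attend` names are hypothetical, and I am not claiming this is how RASPy tracks layers internally.

```python
class Node:
    """A transform tagged with the number of attention layers it needs."""

    def __init__(self, name, layer=0):
        self.name = name
        self.layer = layer


def elementwise(name, *inputs):
    # Feed-forward style computation adds no attention layer.
    return Node(name, max(node.layer for node in inputs))


def attend(name, *inputs):
    # An attention step sits one layer above its deepest input.
    return Node(name, 1 + max(node.layer for node in inputs))


tokens, indices, one = Node("tokens"), Node("indices"), Node("1")

length = attend("length", one)                            # layer 1
diff = elementwise("length - indices", length, indices)   # still layer 1
x = attend("cumsum", indices, diff)                       # layer 2

print(x.layer)  # 2
```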