Thinking Like Transformers

Transformer models are foundational to AI systems. There are now countless explanations of "how transformers work" in the sense of the architecture diagram at the heart of the model.


However, this diagram does not provide any intuition into the computational model of this framework. As researchers become interested in how Transformers work, gaining intuition into their mechanisms becomes increasingly useful.

Thinking like Transformers proposes a computational framework for Transformer-like calculations. The framework uses discrete computation to simulate Transformer computations. The resulting language RASP is a programming language where every program compiles down to a specific Transformer.

In this blog post, I reimplemented a variant of RASP in Python (RASPy). The language is roughly compatible with the original version, but with some syntactic changes that I thought were fun. With this language, Gail Weiss, the author of the original work, provided a challenging set of puzzles to walk through and understand how it works.

#!pip install git+https://github.com/srush/RASPy

Before jumping into the language itself, let's look at an example of what coding with Transformers looks like. Here is some code that computes flip, i.e., reverses an input sequence. The code itself uses two Transformer layers to apply attention and mathematical computations to achieve the result.

def flip():
    length = (key(1) == query(1)).value(1)
    flip = (key(length - indices - 1) == query(indices)).value(tokens)
    return flip
flip()
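Before unpacking the language, the semantics of flip can be sketched in plain Python (this simulation is my own, not part of the RASPy API): the first layer computes the sequence length, and the second attends from each position i to position length - i - 1.

```python
# Plain-Python simulation of what the two layers of flip() compute.
# (Illustrative sketch only -- not the RASPy API.)
def flip_sim(seq):
    # Layer 1: a "select all" attention head sums a value of 1 at every
    # position, yielding the sequence length n.
    n = len(seq)
    # Layer 2: each query position i attends to the key at n - i - 1
    # and copies its token, reversing the sequence.
    return [seq[n - i - 1] for i in range(n)]

print(flip_sim("hello"))  # ['o', 'l', 'l', 'e', 'h']
```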


Transformers as Code

Our goal is to define a computational formalism that mimics the expressivity of Transformers. We will go through this process by analogy, describing each language construct next to the aspect of the Transformer it represents. (See the full paper for the formal language specification).

The core unit of the language is a sequence operation that transforms a sequence to another sequence of the same length. I will refer to these throughout as transforms.

Inputs

In a Transformer, the base layer is the input fed to the model. This input usually contains the raw tokens as well as positional information.


In code, the symbol tokens represents the simplest transform. It returns the tokens passed to the model. The default input is the sequence “hello”.

tokens

If we want to change the input to the transform, we use the input method to pass in an alternative.

tokens.input([5, 2, 4, 5, 2, 2])

As with Transformers, we cannot access the positions of these sequences directly. However, to mimic position embeddings, we have access to a sequence of indices.

indices
sop = indices
sop.input("goodbye")
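As a mental model, tokens and indices can be simulated as ordinary functions from an input sequence to an output sequence of the same length (a sketch of the semantics, not the actual RASPy implementation):

```python
# Hypothetical plain-Python analogues of the two base transforms.
def tokens_sim(seq):
    # Identity transform: returns the input tokens unchanged.
    return list(seq)

def indices_sim(seq):
    # Position transform: returns 0..n-1, mimicking position embeddings.
    return list(range(len(seq)))

print(tokens_sim("hello"))     # ['h', 'e', 'l', 'l', 'o']
print(indices_sim("goodbye"))  # [0, 1, 2, 3, 4, 5, 6]
```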

Feed Forward Network

After the input layer, we reach the feed-forward network. In a Transformer, this stage can apply mathematical operations to each element of the sequence independently.


In code, we represent this stage by computation on transforms. Mathematical operations are overloaded to represent independent computation on each element of the sequence.

tokens == "l"

The result is a new transform. Once constructed it can be applied to new input.

model = tokens * 2 - 1
model.input([1, 2, 3, 5, 2])

Operations can combine multiple transforms. For example, we can compute functions of both tokens and indices. The analogy here is that Transformer activations can keep track of multiple pieces of information simultaneously.

model = tokens - 5 + indices
model.input([1, 2, 3, 5, 2])
(tokens == "l") | (indices == 1)
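The elementwise semantics can be sketched in plain Python: every overloaded operation applies position by position, with no interaction between positions (again a simulation of the idea, not the RASPy internals):

```python
# Sketch: feed-forward-style ops act independently at each position.
def elementwise(fn, seq):
    return [fn(x) for x in seq]

# Analogue of `model = tokens * 2 - 1` from above.
model_sim = lambda seq: elementwise(lambda x: x * 2 - 1, seq)
print(model_sim([1, 2, 3, 5, 2]))  # [1, 3, 5, 9, 3]

# Analogue of `(tokens == "l") | (indices == 1)` on input "hello".
combined = [t == "l" or i == 1 for i, t in enumerate("hello")]
print(combined)  # [False, True, True, True, False]
```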

We provide a few helper functions to make it easier to write transforms. For example, where provides an "if"-statement-like construct.

where((tokens == "h") | (tokens == "l"), tokens, "q")

And map lets us define our own operators, for instance a string-to-int transform. (Users should be careful to only use operations here that could be computed with a simple neural network.)

atoi = tokens.map(lambda x: ord(x) - ord('0'))
atoi.input("31234")

When chaining these transforms, it is often easier to write them as functions. For example, the following applies where, then atoi, and then adds 2.

def atoi(seq=tokens):
    return seq.map(lambda x: ord(x) - ord('0')) 

op = (atoi(where(tokens == "-", "0", tokens)) + 2)
op.input("02-13")
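The chained transform above can be traced step by step in plain Python (a hand simulation, not RASPy):

```python
# Hand-simulation of: atoi(where(tokens == "-", "0", tokens)) + 2
def op_sim(seq):
    cleaned = ["0" if c == "-" else c for c in seq]  # where(...)
    digits = [ord(c) - ord("0") for c in cleaned]    # atoi
    return [d + 2 for d in digits]                   # + 2

print(op_sim("02-13"))  # [2, 4, 2, 3, 5]
```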

Attention Selectors

Things get more interesting when we start to apply attention. This allows routing of information between the different elements of the sequence.


We begin by defining notation for the keys and queries of the model. Keys and queries can be created directly from the transforms defined above. For example, to define a key we call key.

key(tokens)

Similarly for query.

query(tokens)

Scalars can be used as keys or queries. They broadcast out to the length of the underlying sequence.

query(1)

By applying a comparison operation between keys and queries, we create a selector. This corresponds to a binary matrix indicating which keys each query attends to. Unlike in Transformers, this attention matrix is unweighted.

eq = (key(tokens) == query(tokens))
eq
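Concretely, a selector can be thought of as a boolean matrix sel[q][k] that is true when query position q attends to key position k. A small simulation (my own sketch, with pred taking the key value and the query value):

```python
# Sketch: build the unweighted attention matrix for a selector.
def make_selector(seq, pred):
    n = len(seq)
    return [[pred(seq[k], seq[q])  # true iff query q attends to key k
             for k in range(n)]
            for q in range(n)]

eq_sim = make_selector("hello", lambda key, query: key == query)
# The query at position 2 ('l') attends to both 'l' keys:
print(eq_sim[2])  # [False, False, True, True, False]
```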

Some examples:

offset = (key(indices) == query(indices - 1))
offset
before = key(indices) < query(indices)
before
after = key(indices) > query(indices)
after

Selectors can be merged with boolean operations. For example, this selector attends only to tokens before it in time that have the same value. We show this by including both pairs of keys and queries in the matrix.

before & eq

Using Attention

Given an attention selector, we can provide a value sequence to aggregate. We represent aggregation by summing over the values whose selector entries are true.

(Note: in the original paper, they use a mean aggregation and show a clever construction where mean aggregation is able to represent a sum calculation. RASPy uses sum by default for simplicity and to avoid fractions. In practice this means that RASPy may underestimate the number of layers needed to convert to a mean-based model by a factor of 2.)

Attention aggregation gives us the ability to compute functions like histograms.

(key(tokens) == query(tokens)).value(1)
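The histogram example can be simulated by combining a selector matrix with sum aggregation (an illustrative sketch, not the RASPy internals):

```python
# Sketch: aggregation sums the value at every selected key position.
def selector(seq, pred):
    n = len(seq)
    return [[pred(seq[k], seq[q]) for k in range(n)] for q in range(n)]

def aggregate(sel, values):
    return [sum(v for k, v in enumerate(values) if row[k]) for row in sel]

seq = "hello"
# Each position attends to every equal token and sums a value of 1,
# giving each token's count in the sequence.
hist = aggregate(selector(seq, lambda key, query: key == query),
                 [1] * len(seq))
print(hist)  # [1, 1, 2, 2, 1]
```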

Visually we follow the architecture diagram. Queries are to the left, Keys at the top, Values at the bottom, and the Output is to the right.


Some attention operations may not even use the input tokens. For instance, to compute the length of a sequence, we create a "select all" attention selector and then sum the values.

length = (key(1) == query(1)).value(1)
length = length.name("length")
length
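The same matrix view shows why this computes length: the select-all matrix is all true, so each position sums a value of 1 over all n positions (again my own sketch, not RASPy):

```python
# Sketch: "select all" attention with value 1 yields the length n
# at every position.
def length_sim(seq):
    n = len(seq)
    sel = [[True] * n for _ in range(n)]  # key(1) == query(1) everywhere
    return [sum(1 for selected in row if selected) for row in sel]

print(length_sim("hello"))  # [5, 5, 5, 5, 5]
```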

Here’s a more complex example, shown step-by-step. (This is the kind of thing they ask in interviews!)

Say we want to compute the sum of neighboring values in a sequence. First we apply the forward cutoff.

WINDOW=3
s1 = (key(indices) >= query(indices - WINDOW + 1))  
s1

Then the backward cutoff.

s2 = (key(indices) <= query(indices))
s2

Intersect.

sel = s1 & s2
sel

And finally aggregate.

sum2 = sel.value(tokens) 
sum2.input([1,3,2,2,2])
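The window-sum pipeline can be checked against a direct plain-Python computation (my sketch): query i attends to keys k with i - WINDOW + 1 <= k <= i and sums the selected tokens.

```python
# Direct simulation of s1 & s2 followed by .value(tokens).
WINDOW = 3

def sum2_sim(seq):
    n = len(seq)
    return [sum(seq[k] for k in range(n)
                if i - WINDOW + 1 <= k <= i)  # forward & backward cutoffs
            for i in range(n)]

print(sum2_sim([1, 3, 2, 2, 2]))  # [1, 4, 6, 7, 6]
```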

Here’s a similar example with a cumulative sum. We introduce here the ability to name a transform which helps with debugging.

def cumsum(seq=tokens):
    x = (before | (key(indices) == query(indices))).value(seq)
    return x.name("cumsum")
cumsum().input([3, 1, -2, 3, 1])
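The cumsum selector (key index <= query index) reduces to a familiar prefix sum, as a plain-Python check shows:

```python
# Sketch: the selector `before | (key(indices) == query(indices))`
# attends to all positions k <= i, so aggregation is a prefix sum.
def cumsum_sim(seq):
    return [sum(seq[k] for k in range(i + 1)) for i in range(len(seq))]

print(cumsum_sim([3, 1, -2, 3, 1]))  # [3, 4, 2, 5, 6]
```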

Layers

The language supports building up more complex transforms. It determines the number of layers required by tracking the operations computed so far.


Here is a simple example that produces a 2-layer transform. The first corresponds to computing length and the second the cumulative sum.

x = cumsum(length - indices)
x.input([3, 2, 3, 5])
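Tracing this 2-layer transform by hand (a simulation of the expected semantics, not RASPy output): layer 1 computes the length, the feed-forward stage computes length - indices, and layer 2 takes the cumulative sum.

```python
# Hand-simulation of cumsum(length - indices) on an input sequence.
def two_layer_sim(seq):
    n = len(seq)                                # layer 1: length
    ff = [n - i for i in range(n)]              # feed-forward: length - indices
    return [sum(ff[:i + 1]) for i in range(n)]  # layer 2: cumsum

print(two_layer_sim([3, 2, 3, 5]))  # [4, 7, 9, 10]
```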