Tensor Puzzles - Penzai Edition¶

  • by Sasha Rush - srush_nlp

This is a version of the Tensor Puzzles implemented with the JAX Penzai library. It is available on GitHub.

Penzai is a great fit for these puzzles, both because it comes with a clean built-in visualization library and because it has a very nice named-tensor implementation.

I recommend running in Colab.

In [3]:
!pip install -qqq jaxtyping hypothesis pytest penzai
import jax.numpy as np
import numpy as onp
from penzai import pz
# Short aliases for the Penzai named-array helpers used by every puzzle below.
arange = pz.nx.arange
# np.where passed through pz.nx.nmap so that it can be called on named arrays
# (see the "Example of where" cell).
where = pz.nx.nmap(np.where)
wrap = pz.nx.wrap
# NOTE(review): presumably installs Penzai's renderer as the default notebook
# display -- confirm against the Penzai documentation.
pz.ts.register_as_default()

# Optional automatic array visualization extras:
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
# Configure the interactive array visualizer; axis "j" is preferred for
# columns and axis "i" for rows.
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer(force_continuous=True, around_zero=True,  prefers_column=["j"], prefers_row=["i"]))
In [4]:
import inspect
import random
from jaxtyping import Int32
NamedArray = pz.nx.NamedArray
def make_test(name, problem, problem_spec, add_sizes=None,
              init_size=None,
              constraint=lambda d: d):
    """Check a puzzle solution against its reference spec on random inputs.

    The dimension names of every argument (and of the return value) are read
    from the jaxtyping annotations on `problem`. Random inputs are generated,
    `problem_spec` fills in the expected output in place on plain Python
    lists, and `problem` is called on the named-array versions of the same
    inputs. Prints "Correct" if every trial matches, "Failure" otherwise.

    Args:
        name: Label for the test. Accepted for compatibility; not used.
        problem: The annotated solution function under test.
        problem_spec: Reference implementation. It receives the inputs as
            nested lists, in signature order, followed by the output buffer,
            which it mutates in place.
        add_sizes: Accepted for compatibility; not used.
        init_size: Optional mapping from the first character of a dimension
            name to a fixed size. Other sizes are drawn from 2..7.
        constraint: Hook applied to the dict of generated arrays (including
            the "return" entry) before the spec runs.

    Returns:
        A list with one dict per trial holding the inputs, the "target"
        array and, when the solution returned something, the "yours" array.
    """
    num_trials = 3
    # Avoid a shared mutable default; None means "no fixed sizes".
    init_size = {} if init_size is None else init_size

    signature = inspect.signature(problem)
    args = {
        n: [d.name for d in p.annotation.dims]
        for n, p in signature.parameters.items()
    }
    args["return"] = [d.name for d in signature.return_annotation.dims]

    def make_instance():
        """Build one random example plus its nested-list mirror for the spec."""
        example = {}
        reg = {}
        sizes = dict(init_size)
        for n in args:
            size = {}
            for dim in args[n]:
                # Sizes are keyed by the first character only, so e.g. "i1"
                # and "i2" always get the same length.
                if dim[0] not in sizes:
                    sizes[dim[0]] = random.randint(2, 7)
                size[dim] = sizes[dim[0]]
            if "_s" in n:
                # Arguments named "*_s" are index vectors along their first
                # dimension rather than random data.
                first_dim = list(size.keys())[0]
                example[n] = pz.nx.arange(first_dim, size[first_dim])
            else:
                v = onp.random.randint(-5, 5, list(size.values()))
                example[n] = pz.nx.wrap(v).tag(*args[n])
        example = constraint(example)
        for n in args:
            x = example[n].untag(*args[n])
            reg[n] = x.unwrap().tolist()
            if len(args[n]) == 0:
                # Scalars are passed to the spec as a one-element buffer so
                # that it can write the result in place.
                reg[n] = [0]
        return example, reg

    examples = []
    correct = 0
    for _ in range(num_trials):
        example, reg = make_instance()
        del example["return"]
        problem_spec(*reg.values())
        expected = reg["return"]
        # Unwrap the one-element buffer only for true scalar outputs; testing
        # the list length instead would also unwrap outputs whose leading
        # dimension happens to have size 1.
        if not args["return"]:
            expected = expected[0]
        yours = problem(**example)
        example["target"] = wrap(expected).tag(*args["return"])
        if yours is not None:
            example["yours"] = yours
            same = example["target"] == yours
            if same.untag(*same.named_shape.keys()).unwrap().all():
                correct += 1
        # A solution returning None simply counts as a failed trial.
        examples.append(example)
    if correct == num_trials:
        print("Correct")
    else:
        print("Failure")
    return examples

Rules¶

  1. Each puzzle needs to be solved in 1 line (<80 columns) of code.
  2. You are only allowed to use contract, where and indexing.
  3. You are not allowed anything else. No view, sum, take, squeeze, tensor.

Example of named infix ops.¶

In [5]:
# Two operand pairs: the left and right operand of each pair use different
# axis names.
a = [arange("i", 4), arange("j", 5)]
b = [arange("j", 3), arange("i", 2)]

# Distinct loop names keep the comprehension from shadowing the lists above.
[{"a": lhs, "b": rhs, "ret": lhs + rhs} for lhs, rhs in zip(a, b)]
(Loading...)
Out[5]:

Example of where¶

In [6]:
# Each entry is a (condition, value-if-true, value-if-false) triple.
examples = [
    (wrap([False, True], "i"), wrap([1, 1], "i"), wrap([-1, 0], "i")),
    (wrap([[False, True], [True, False]], "i", "j"), wrap([0, 1], "i"), wrap([-1, 0], "j")),
]
[dict(zip(("q", "a", "b"), triple), ret=where(*triple)) for triple in examples]
(Loading...)
Out[6]:

Example of contraction¶

In [7]:
def contract(n, *ts):
    """Multiply all arrays in `ts` together, then sum the product over axis `n`.

    The running product starts from the scalar 1. The named axis `n` is
    untagged and `np.sum` (wrapped by `pz.nx.nmap`) is applied to the result.
    """
    product = 1
    for factor in ts:
        product = product * factor
    positional = product.untag(n)
    return pz.nx.nmap(np.sum)(positional)

# Same operand pairs as the infix example; each product is summed over "i".
a = [arange("i", 4), arange("j", 5)]
b = [arange("j", 3), arange("i", 2)]

# Distinct loop names keep the comprehension from shadowing the lists above.
[{"a": lhs, "b": rhs, "ret": contract("i", lhs * rhs)} for lhs, rhs in zip(a, b)]
(Loading...)
Out[7]:

Puzzle 1 - ones¶

Compute ones - the vector of all ones.

In [8]:
def ones_spec(i_s, out):
    """Reference for `ones`: write 1 into `out` at every index in `i_s`."""
    for position in i_s:
        out[position] = 1

# Solution: multiplying the index vector by 0 clears it, and adding 1 then
# leaves a 1 at every position along "i".
def ones(i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return i_s * 0 + 1

make_test("one", ones, ones_spec)
Correct
(Loading...)
Out[8]:

Puzzle 2 - sum¶

Compute sum - the sum of a vector.

In [9]:
def sum_spec(a, out):
    """Reference for `sum`: add every element of `a` onto `out[0]`.

    `out[0]` is used as the accumulator, so it must already hold the
    starting value when this is called.
    """
    for value in a:
        out[0] += value

# Solution: contract (sum out) the only axis, "i".
# NOTE(review): this definition shadows the built-in `sum` for every cell that
# runs after it.
def sum(a: Int32[NamedArray, "i"]) -> Int32[NamedArray, ""]:
    return contract("i", a)

make_test("sum", sum, sum_spec)
Correct
(Loading...)
Out[9]:

Puzzle 3 - outer¶

Compute outer - the outer product of two vectors.

In [10]:
def outer_spec(a, b, out):
    """Reference for `outer`: fill `out[i][j]` with `a[i] * b[j]`."""
    for i in range(len(out)):
        row = out[i]
        for j in range(len(out[0])):
            row[j] = a[i] * b[j]

# Solution: `a` carries axis "i" and `b` carries axis "j", so the plain product
# is expected to broadcast by name into an "i" x "j" table (the recorded test
# output below reports "Correct").
def outer(a: Int32[NamedArray, "i"], b : Int32[NamedArray, "j"]) -> Int32[NamedArray, "i j"]:
    return a * b

make_test("outer", outer, outer_spec)
Correct
(Loading...)
Out[10]:

Puzzle 4 - diag¶

Compute diag - the diagonal vector of a square matrix.

In [11]:
def diag_spec(a, i1_s, out):
    """Reference for `diag`: copy the main diagonal of `a` into `out`.

    `i1_s` is part of the puzzle signature but is not read here.
    """
    for k, row in enumerate(a):
        out[k] = row[k]

# Solution: index both axes of `a` with the same index vector, so position k
# of the result reads a[k][k].
def diag(a: Int32[NamedArray, "i1 i2"], i1_s: Int32[NamedArray, "i1"]) -> Int32[NamedArray, "i1"]:
    return a[{"i1": i1_s, "i2": i1_s}]


make_test("diag", diag, diag_spec)
Correct
(Loading...)
Out[11]:

Puzzle 5 - eye¶

Compute eye - the identity matrix.

In [25]:
def eye_spec(i1_s, i2_s, out):
    """Reference for `eye`: `out[i][j]` is 1 where i == j and 0 elsewhere."""
    for row in i1_s:
        for col in i2_s:
            out[row][col] = 1 if row == col else 0

# Solution: compare the row indices against the column indices elementwise.
# NOTE(review): the comparison presumably yields a boolean array rather than
# Int32; the test harness only checks equality against the 0/1 target.
def eye(i1_s: Int32[NamedArray, "i1"], i2_s : Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i1 i2"]:
    return i1_s == i2_s


make_test("eye", eye, eye_spec)
Correct
(Loading...)
Out[25]:

Puzzle 6 - triu¶

Compute triu - the upper triangular matrix.

In [26]:
def triu_spec(i1_s, i2_s, out):
    """Reference for `triu`: `out[i][j]` is 1 where i <= j and 0 elsewhere."""
    for row in i1_s:
        for col in i2_s:
            out[row][col] = 1 if row <= col else 0

# Solution: a cell is set exactly when its row index does not exceed its
# column index.
# NOTE(review): as with `eye`, the result is presumably boolean rather than Int32.
def triu(i1_s: Int32[NamedArray, "i1"], i2_s : Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i1 i2"]:
    return i1_s <= i2_s


make_test("triu", triu, triu_spec)
Correct
(Loading...)
Out[26]:

Puzzle 7 - cumsum¶

Compute cumsum - the cumulative sum.

In [27]:
def cumsum_spec(a, i1_s, i2_s, out):
    """Reference for `cumsum`: `out[i]` is the sum of `a[0]` through `a[i]`.

    `i1_s` and `i2_s` are part of the puzzle signature but are not read here.
    """
    running = 0
    for k in range(len(out)):
        running += a[k]
        out[k] = running

# Solution: weight a[i1] by the mask (i1 <= i2) and sum over "i1", so each
# output position i2 accumulates every input position up to and including it.
def cumsum(a: Int32[NamedArray, "i1"], i1_s : Int32[NamedArray, "i1"], i2_s: Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i2"]:
    return contract("i1", i1_s <= i2_s, a)

make_test("cumsum", cumsum, cumsum_spec)
Correct
(Loading...)
Out[27]:

Puzzle 8 - diff¶

Compute diff - the running difference.

In [28]:
def diff_spec(a, i_s, out):
    """Reference for `diff`: `out[i] = a[i] - a[i - 1]` for every position.

    The loop starts at 0, so the first entry is `a[0] - a[-1]`: the
    difference wraps around to the last element. The earlier
    `out[0] = a[0]` assignment was always overwritten by the first loop
    iteration and raised IndexError on empty input, so it has been removed.

    `i_s` is part of the puzzle signature but is not read here.
    """
    for i in range(len(out)):
        out[i] = a[i] - a[i - 1]

# Solution: subtract from each element the one at the previous index.
# NOTE(review): at position 0 the lookup index is -1, which presumably wraps to
# the last element; that matches diff_spec, whose loop also starts at 0.
def diff(a: Int32[NamedArray, "i"], i1_s : Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return a - a[{"i": i1_s - 1}]

make_test("diff", diff, diff_spec)
Correct
(Loading...)
Out[28]:

Puzzle 9 - stack¶

Compute vstack - the matrix of two vectors

In [30]:
def stack_spec(a, b, j_s, out):
    """Reference for `stack`: row 0 of `out` receives `a`, row 1 receives `b`.

    `j_s` is part of the puzzle signature but is not read here.
    """
    for col in range(len(out[0])):
        for row, source in enumerate((a, b)):
            out[row][col] = source[col]

# Solution: select `b` where the "j" index equals 1 and `a` elsewhere; the test
# below fixes the "j" axis to size 2, so row 0 is `a` and row 1 is `b`.
def stack(a: Int32[NamedArray, "i"], b: Int32[NamedArray, "i"],
          j_s: Int32[NamedArray, "j"]) -> Int32[NamedArray, "j i"]:
    return where(j_s == 1, b, a)


make_test("stack", stack, stack_spec, init_size={"j" : 2})
Correct
(Loading...)
Out[30]:

Puzzle 10 - roll¶

Compute roll - the vector shifted 1 circular position.

In [17]:
def roll_spec(a, i_s, out):
    """Reference for `roll`: `out[i]` takes `a[i + 1]`, wrapping at the end.

    `i_s` is part of the puzzle signature but is not read here.
    """
    length = len(out)
    for k in range(length):
        # (k + 1) % length equals k + 1 everywhere except the last position,
        # where it wraps to 0.
        out[k] = a[(k + 1) % length]

# Solution: look up each position's successor; taking the index modulo the
# length of axis "i" sends the last position back to index 0.
def roll(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return a[{"i": (i_s + 1) % i_s.named_shape["i"]}]


make_test("roll", roll, roll_spec)
Correct
(Loading...)
Out[17]:

Puzzle 11 - flip¶

Compute flip - the reversed vector

In [32]:
def flip_spec(a, i_s, out):
    """Reference for `flip`: `out` receives the elements of `a` in reverse.

    `i_s` is part of the puzzle signature but is not read here.
    """
    length = len(out)
    for k in range(length):
        out[k] = a[length - k - 1]

# Solution: index with -1, -2, ... (that is, -i - 1).
# NOTE(review): relies on negative indices counting back from the end of the
# axis; the recorded test output reports "Correct".
def flip(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return a[{"i": -i_s - 1}]

make_test("flip", flip, flip_spec)
Correct
(Loading...)
Out[32]:

Puzzle 12 - compress¶

Compute compress - keep only masked entries (left-aligned).

In [19]:
def compress_spec(g, v, i1_s, i2_s, out):
    """Reference for `compress`: pack the selected entries of `v` to the left.

    `out` is zero-filled first. Then every `v[i]` whose gate `g[i]` is
    greater than 1 is written to the next free slot, in order.
    `i1_s` and `i2_s` are part of the puzzle signature but are not read here.
    """
    for k in range(len(out)):
        out[k] = 0
    slot = 0
    for k, gate in enumerate(g):
        if gate > 1:
            out[slot] = v[k]
            slot += 1

# Unsolved placeholder: it returns the gate vector unchanged, which is why the
# recorded test output for this puzzle is "Failure".
def compress(g: Int32[NamedArray, "i1"], v: Int32[NamedArray, "i2"], i1_s:Int32[NamedArray, "i1"], i2_s:Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i2"]:
    # I don't know how to do this one!
    return g

make_test("compress", compress, compress_spec)
Failure
(Loading...)
Out[19]:

Puzzle 13 - pad_to¶

Compute pad_to - eliminate or add 0s to change size of vector.

In [33]:
def pad_to_spec(a, i_s, j_s, out):
    """Reference for `pad_to`: copy `a` into `out`, truncating or zero-padding.

    Positions of `out` beyond the end of `a` are set to 0.
    `i_s` and `j_s` are part of the puzzle signature but are not read here.
    """
    for k in range(len(out)):
        out[k] = a[k] if k < len(a) else 0

# Solution: sum a[i] * (j == i) over "i". Output position j picks up a[j] when
# some input index matches it and is 0 otherwise.
def pad_to(a: Int32[NamedArray, "i"], i_s:Int32[NamedArray, "i"], j_s:Int32[NamedArray, "j"])  -> Int32[NamedArray, "j"]:
    return contract("i", a, j_s == i_s)


make_test("pad_to", pad_to, pad_to_spec)
Correct
(Loading...)
Out[33]:

Puzzle 14 - sequence_mask¶

Compute sequence_mask - pad out to length per batch.

In [21]:
# Didn't do
# def sequence_mask_spec(values, length, out):
#     for i in range(len(out)):
#         for j in range(len(out[0])):
#             if j < length[i]:
#                 out[i][j] = values[i][j]
#             else:
#                 out[i][j] = 0

# def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
#     pass


# def constraint_set_length(d):
#     d["length"] = d["length"] % d["values"].shape[1]
#     return d

# make_test("sequence_mask",
#     sequence_mask, sequence_mask_spec, constraint=constraint_set_length
# )

Puzzle 15 - bincount¶

Compute bincount - count number of times an entry was seen.

In [35]:
def bincount_spec(a, i_s, j1_s, j2_s, out):
    """Reference for `bincount`: `out[k]` is the number of times k occurs in `a`.

    `out` is zero-filled first. `i_s`, `j1_s` and `j2_s` are part of the
    puzzle signature but are not read here.
    """
    for k in range(len(out)):
        out[k] = 0
    for value in a:
        out[value] += 1

# Solution: (j1 == j2) is an identity-style table; indexing its "j1" axis with
# `a` presumably yields, for each element of `a`, a one-hot row over "j2".
# Summing those rows over "i" gives the count per bin.
def bincount(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"],
             j1_s: Int32[NamedArray, "j1"], j2_s: Int32[NamedArray, "j2"]) -> Int32[NamedArray, "j2"]:
    return contract("i", (j1_s == j2_s)[{"j1": a}])


def constraint_set_max(d):
    """Reduce `d["a"]` modulo the size of the output's "j2" axis.

    This keeps every entry of `a` a valid bin index. `d` is modified in
    place and also returned.
    """
    raw = d["a"]
    num_bins = d["return"].named_shape["j2"]
    d["a"] = raw % num_bins
    return d


make_test("bincount",
    bincount, bincount_spec, constraint=constraint_set_max
)
Correct
(Loading...)
Out[35]:

Puzzle 16 - scatter_add¶

Compute scatter_add - add together values that link to the same location.

In [36]:
def scatter_add_spec(values, link, j1_s, j2_s, out):
    """Reference for `scatter_add`: add each `values[j]` onto `out[link[j]]`.

    `out` is zero-filled first. `j1_s` and `j2_s` are part of the puzzle
    signature but are not read here.
    """
    for k in range(len(out)):
        out[k] = 0
    for j, value in enumerate(values):
        out[link[j]] += value

# Solution: same one-hot construction as `bincount` ((j1 == j2) indexed along
# "j1" by `link`), but each row is weighted by its value before summing over
# "i".
def scatter_add(values: Int32[NamedArray, "i"], link: Int32[NamedArray,"i"],
                j1_s: Int32[NamedArray, "j1"], j2_s: Int32[NamedArray, "j2"]) -> Int32[NamedArray, "j2"]:
    return contract("i", values, (j1_s == j2_s)[{"j1": link}])


def constraint_set_max(d):
    """Reduce `d["link"]` modulo the size of the output's "j2" axis.

    This keeps every link a valid output index. `d` is modified in place
    and also returned.
    """
    raw = d["link"]
    num_slots = d["return"].named_shape["j2"]
    d["link"] = raw % num_slots
    return d

make_test("scatter_add",
    scatter_add, scatter_add_spec, constraint=constraint_set_max
)
Correct
(Loading...)
Out[36]:
In [24]:
import inspect
# Every puzzle solution, in notebook order.
fns = (ones, sum, outer, diag, eye, triu, cumsum, diff, stack, roll, flip,
       compress, pad_to,  bincount, scatter_add)

# Print the length of each solution line. Comment lines are ignored, and the
# reported length includes the line's leading indentation.
for fn in fns:
    lines = []
    for text in inspect.getsource(fn).split("\n"):
        if not text.strip().startswith("#"):
            lines.append(text)

    if len(lines) <= 3:
        # One signature line, then the solution line at index 1.
        print(fn.__name__, len(lines[1]))
    else:
        # More than three lines: the entry at index 2 is measured instead
        # (for the two-line signatures in this notebook, that is the solution).
        print(fn.__name__, len(lines[2]), "(more than 1 line)")
ones 22
sum 27
outer 16
diag 38
eye 36
triu 36
cumsum 55
diff 33
stack 43
roll 53
flip 31
compress 12
pad_to 52
bincount 52 (more than 1 line)
scatter_add 63 (more than 1 line)
In [39]:
%%shell
jupyter nbconvert --to html /content/Tensor_Puzzlers_Penzai.ipynb
[NbConvertApp] WARNING | pattern '/content/Tensor_Puzzlers_Penzai.ipynb' matched no files
This application is used to convert notebook files (*.ipynb)
        to various other formats.

        WARNING: THE COMMANDLINE INTERFACE MAY CHANGE IN FUTURE RELEASES.

Options
=======
The options below are convenience aliases to configurable class-options,
as listed in the "Equivalent to" description-line of the aliases.
To see all configurable class-options for some <cmd>, use:
    <cmd> --help-all

--debug
    set log level to logging.DEBUG (maximize logging output)
    Equivalent to: [--Application.log_level=10]
--show-config
    Show the application's configuration (human-readable format)
    Equivalent to: [--Application.show_config=True]
--show-config-json
    Show the application's configuration (json format)
    Equivalent to: [--Application.show_config_json=True]
--generate-config
    generate default config file
    Equivalent to: [--JupyterApp.generate_config=True]
-y
    Answer yes to any questions instead of prompting.
    Equivalent to: [--JupyterApp.answer_yes=True]
--execute
    Execute the notebook prior to export.
    Equivalent to: [--ExecutePreprocessor.enabled=True]
--allow-errors
    Continue notebook execution even if one of the cells throws an error and include the error message in the cell output (the default behaviour is to abort conversion). This flag is only relevant if '--execute' was specified, too.
    Equivalent to: [--ExecutePreprocessor.allow_errors=True]
--stdin
    read a single notebook file from stdin. Write the resulting notebook with default basename 'notebook.*'
    Equivalent to: [--NbConvertApp.from_stdin=True]
--stdout
    Write notebook output to stdout instead of files.
    Equivalent to: [--NbConvertApp.writer_class=StdoutWriter]
--inplace
    Run nbconvert in place, overwriting the existing notebook (only
            relevant when converting to notebook format)
    Equivalent to: [--NbConvertApp.use_output_suffix=False --NbConvertApp.export_format=notebook --FilesWriter.build_directory=]
--clear-output
    Clear output of current file and save in place,
            overwriting the existing notebook.
    Equivalent to: [--NbConvertApp.use_output_suffix=False --NbConvertApp.export_format=notebook --FilesWriter.build_directory= --ClearOutputPreprocessor.enabled=True]
--no-prompt
    Exclude input and output prompts from converted document.
    Equivalent to: [--TemplateExporter.exclude_input_prompt=True --TemplateExporter.exclude_output_prompt=True]
--no-input
    Exclude input cells and output prompts from converted document.
            This mode is ideal for generating code-free reports.
    Equivalent to: [--TemplateExporter.exclude_output_prompt=True --TemplateExporter.exclude_input=True --TemplateExporter.exclude_input_prompt=True]
--allow-chromium-download
    Whether to allow downloading chromium if no suitable version is found on the system.
    Equivalent to: [--WebPDFExporter.allow_chromium_download=True]
--disable-chromium-sandbox
    Disable chromium security sandbox when converting to PDF..
    Equivalent to: [--WebPDFExporter.disable_sandbox=True]
--show-input
    Shows code input. This flag is only useful for dejavu users.
    Equivalent to: [--TemplateExporter.exclude_input=False]
--embed-images
    Embed the images as base64 dataurls in the output. This flag is only useful for the HTML/WebPDF/Slides exports.
    Equivalent to: [--HTMLExporter.embed_images=True]
--sanitize-html
    Whether the HTML in Markdown cells and cell outputs should be sanitized..
    Equivalent to: [--HTMLExporter.sanitize_html=True]
--log-level=<Enum>
    Set the log level by value or name.
    Choices: any of [0, 10, 20, 30, 40, 50, 'DEBUG', 'INFO', 'WARN', 'ERROR', 'CRITICAL']
    Default: 30
    Equivalent to: [--Application.log_level]
--config=<Unicode>
    Full path of a config file.
    Default: ''
    Equivalent to: [--JupyterApp.config_file]
--to=<Unicode>
    The export format to be used, either one of the built-in formats
            ['asciidoc', 'custom', 'html', 'latex', 'markdown', 'notebook', 'pdf', 'python', 'rst', 'script', 'slides', 'webpdf']
            or a dotted object name that represents the import path for an
            ``Exporter`` class
    Default: ''
    Equivalent to: [--NbConvertApp.export_format]
--template=<Unicode>
    Name of the template to use
    Default: ''
    Equivalent to: [--TemplateExporter.template_name]
--template-file=<Unicode>
    Name of the template file to use
    Default: None
    Equivalent to: [--TemplateExporter.template_file]
--theme=<Unicode>
    Template specific theme(e.g. the name of a JupyterLab CSS theme distributed
    as prebuilt extension for the lab template)
    Default: 'light'
    Equivalent to: [--HTMLExporter.theme]
--sanitize_html=<Bool>
    Whether the HTML in Markdown cells and cell outputs should be sanitized.This
    should be set to True by nbviewer or similar tools.
    Default: False
    Equivalent to: [--HTMLExporter.sanitize_html]
--writer=<DottedObjectName>
    Writer class used to write the
                                        results of the conversion
    Default: 'FilesWriter'
    Equivalent to: [--NbConvertApp.writer_class]
--post=<DottedOrNone>
    PostProcessor class used to write the
                                        results of the conversion
    Default: ''
    Equivalent to: [--NbConvertApp.postprocessor_class]
--output=<Unicode>
    overwrite base name use for output files.
                can only be used when converting one notebook at a time.
    Default: ''
    Equivalent to: [--NbConvertApp.output_base]
--output-dir=<Unicode>
    Directory to write output(s) to. Defaults
                                  to output to the directory of each notebook. To recover
                                  previous default behaviour (outputting to the current
                                  working directory) use . as the flag value.
    Default: ''
    Equivalent to: [--FilesWriter.build_directory]
--reveal-prefix=<Unicode>
    The URL prefix for reveal.js (version 3.x).
            This defaults to the reveal CDN, but can be any url pointing to a copy
            of reveal.js.
            For speaker notes to work, this must be a relative path to a local
            copy of reveal.js: e.g., "reveal.js".
            If a relative path is given, it must be a subdirectory of the
            current directory (from which the server is run).
            See the usage documentation
            (https://nbconvert.readthedocs.io/en/latest/usage.html#reveal-js-html-slideshow)
            for more details.
    Default: ''
    Equivalent to: [--SlidesExporter.reveal_url_prefix]
--nbformat=<Enum>
    The nbformat version to write.
            Use this to downgrade notebooks.
    Choices: any of [1, 2, 3, 4]
    Default: 4
    Equivalent to: [--NotebookExporter.nbformat_version]

Examples
--------

    The simplest way to use nbconvert is

            > jupyter nbconvert mynotebook.ipynb --to html

            Options include ['asciidoc', 'custom', 'html', 'latex', 'markdown', 'notebook', 'pdf', 'python', 'rst', 'script', 'slides', 'webpdf'].

            > jupyter nbconvert --to latex mynotebook.ipynb

            Both HTML and LaTeX support multiple output templates. LaTeX includes
            'base', 'article' and 'report'.  HTML includes 'basic', 'lab' and
            'classic'. You can specify the flavor of the format used.

            > jupyter nbconvert --to html --template lab mynotebook.ipynb

            You can also pipe the output to stdout, rather than a file

            > jupyter nbconvert mynotebook.ipynb --stdout

            PDF is generated via latex

            > jupyter nbconvert mynotebook.ipynb --to pdf

            You can get (and serve) a Reveal.js-powered slideshow

            > jupyter nbconvert myslides.ipynb --to slides --post serve

            Multiple notebooks can be given at the command line in a couple of
            different ways:

            > jupyter nbconvert notebook*.ipynb
            > jupyter nbconvert notebook1.ipynb notebook2.ipynb

            or you can specify the notebooks list in a config file, containing::

                c.NbConvertApp.notebooks = ["my_notebook.ipynb"]

            > jupyter nbconvert --config mycfg.py

To see all available configurables, use `--help-all`.

---------------------------------------------------------------------------
CalledProcessError                        Traceback (most recent call last)
<ipython-input-39-6787be2b50f1> in <cell line: 1>()
----> 1 get_ipython().run_cell_magic('shell', '', 'jupyter nbconvert --to html /content/Tensor_Puzzlers_Penzai.ipynb\n')

/usr/local/lib/python3.10/dist-packages/google/colab/_shell.py in run_cell_magic(self, magic_name, line, cell)
    332     if line and not cell:
    333       cell = ' '
--> 334     return super().run_cell_magic(magic_name, line, cell)
    335 
    336 

/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py in run_cell_magic(self, magic_name, line, cell)
   2471             with self.builtin_trap:
   2472                 args = (magic_arg_s, cell)
-> 2473                 result = fn(*args, **kwargs)
   2474             return result
   2475 

/usr/local/lib/python3.10/dist-packages/google/colab/_system_commands.py in _shell_cell_magic(args, cmd)
    110   result = _run_command(cmd, clear_streamed_output=False)
    111   if not parsed_args.ignore_errors:
--> 112     result.check_returncode()
    113   return result
    114 

/usr/local/lib/python3.10/dist-packages/google/colab/_system_commands.py in check_returncode(self)
    135   def check_returncode(self):
    136     if self.returncode:
--> 137       raise subprocess.CalledProcessError(
    138           returncode=self.returncode, cmd=self.args, output=self.output
    139       )

CalledProcessError: Command 'jupyter nbconvert --to html /content/Tensor_Puzzlers_Penzai.ipynb
' returned non-zero exit status 255.