This is a version of the Tensor Puzzles implemented with the JAX Penzai library. It is available on GitHub.
Penzai is a really nice fit for these puzzles both because it comes with a really clean visualization library built-in and because it has a very nice named-tensor implementation.
I recommend running in Colab.
!pip install -qqq jaxtyping hypothesis pytest penzai
import jax.numpy as np
import numpy as onp
from penzai import pz
# Short aliases for the Penzai named-array helpers used throughout the notebook.
arange = pz.nx.arange
# Named-array version of np.where, built by wrapping it with pz.nx.nmap.
where = pz.nx.nmap(np.where)
wrap = pz.nx.wrap
# NOTE(review): presumably makes Penzai's treescope renderer the default
# notebook display - confirm against the Penzai documentation.
pz.ts.register_as_default()
# Optional automatic array visualization extras:
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
# Autovisualizer settings: axis "j" is preferred as columns and axis "i" as rows.
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer(force_continuous=True, around_zero=True, prefers_column=["j"], prefers_row=["i"]))
import inspect
import random
from jaxtyping import Int32
# Alias used inside the jaxtyping annotations of the puzzle signatures,
# e.g. Int32[NamedArray, "i"].
NamedArray = pz.nx.NamedArray
def make_test(name, problem, problem_spec, add_sizes=None,
              init_size=None,
              constraint=lambda d: d):
    """Check a puzzle solution against its reference spec on random inputs.

    The axis names of every argument and of the result are read from the
    jaxtyping annotations on ``problem``. Three random instances are built,
    ``problem_spec`` fills in the expected output in place (on plain Python
    lists), and ``problem`` is called with the named arrays. "Correct" is
    printed when all three instances match, "Failure" otherwise.

    Args:
        name: Label of the puzzle (currently unused).
        problem: The solution under test, annotated with
            ``Int32[NamedArray, "..."]`` for every parameter and the return.
        problem_spec: Reference implementation; receives every input as a
            nested list followed by the output list, which it mutates.
        add_sizes: Unused; kept for interface compatibility.
        init_size: Optional mapping from the first character of an axis name
            to a fixed size; other sizes are drawn uniformly from 2..7.
        constraint: Hook applied to the dict of generated named arrays
            (including the "return" entry) before the spec is run.

    Returns:
        A list with one dict per instance holding the inputs, the "target"
        and, when the solution returned something, "yours".
    """
    # None defaults avoid sharing one mutable object across calls.
    init_size = {} if init_size is None else init_size

    signature = inspect.signature(problem)
    args = {
        param: [d.name for d in p.annotation.dims]
        for param, p in signature.parameters.items()
    }
    args["return"] = [d.name for d in signature.return_annotation.dims]

    def make_instance():
        example = {}
        reg = {}
        sizes = dict(init_size)
        for n, dim_names in args.items():
            size = {}
            for dim in dim_names:
                # Axes sharing a first character (e.g. "i1" and "i2") share a size.
                if dim[0] not in sizes:
                    sizes[dim[0]] = random.randint(2, 7)
                size[dim] = sizes[dim[0]]
            if "_s" in n:
                # "*_s" parameters are index vectors 0..size-1 along their axis.
                axis = list(size.keys())[0]
                example[n] = pz.nx.arange(axis, size[axis])
            else:
                v = onp.random.randint(-5, 5, list(size.values()))
                example[n] = pz.nx.wrap(v).tag(*dim_names)
        example = constraint(example)
        for n, dim_names in args.items():
            x = example[n].untag(*dim_names)
            reg[n] = x.unwrap().tolist()
            if len(dim_names) == 0:
                # Scalars are handed to the spec as a one-element accumulator.
                reg[n] = [0]
        return example, reg

    examples = []
    correct = 0
    for _ in range(3):
        example, reg = make_instance()
        # out = example["return"].tolist()
        del example["return"]
        # Dict order is parameters first, then "return", so the output list
        # is the last positional argument of the spec.
        problem_spec(*reg.values())
        # Unwrap the accumulator only for genuinely scalar outputs; testing
        # the length instead would also unwrap a real axis of size 1.
        if not args["return"]:
            reg["return"] = reg["return"][0]
        yours = problem(**example)
        example["target"] = wrap(reg["return"]).tag(*args["return"])
        if yours is not None:
            example["yours"] = yours
            same = example["target"] == example["yours"]
            if same.untag(*same.named_shape.keys()).unwrap().all():
                correct += 1
        examples.append(example)
    if correct == 3:
        print("Correct")
    else:
        print("Failure")
    return examples
`contract`, `where` and indexing. `view`, `sum`, `take`, `squeeze`, `tensor`.
# Broadcasting demo: adding arrays that carry different named axes.
# (A stray "." left over from the preceding prose was removed; it made this
# cell a syntax error.)
a = [arange("i", 4), arange("j", 5)]
b = [arange("j", 3), arange("i", 2)]
# Distinct comprehension names keep the lists `a`/`b` from being shadowed.
[{"a": x, "b": y, "ret": x + y} for x, y in zip(a, b)]
# Demo inputs for `where`: (condition, value-if-true, value-if-false).
examples = [
    (
        wrap([False, True], "i"),
        wrap([1, 1], "i"),
        wrap([-1, 0], "i"),
    ),
    (
        wrap([[False, True], [True, False]], "i", "j"),
        wrap([0, 1], "i"),
        wrap([-1, 0], "j"),
    ),
]
[
    dict(q=cond, a=if_true, b=if_false, ret=where(cond, if_true, if_false))
    for cond, if_true, if_false in examples
]
def contract(n, *ts):
    """Multiply all arrays in ``ts`` together and sum the product over axis ``n``."""
    product = 1
    for factor in ts:
        product = product * factor
    sum_over_axis = pz.nx.nmap(np.sum)
    # Untagging `n` exposes it as the positional axis that np.sum reduces.
    return sum_over_axis(product.untag(n))
# Demo for `contract`: multiply each pair, then sum the product over axis "i".
a = [arange("i", 4), arange("j", 5)]
b = [arange("j", 3), arange("i", 2)]
[
    dict(a=left, b=right, ret=contract("i", left * right))
    for left, right in zip(a, b)
]
def ones_spec(i_s, out):
    """Reference for ``ones``: write 1 at every index listed in ``i_s``."""
    for position in i_s:
        out[position] = 1
# Solution: multiplying the index vector by 0 and adding 1 leaves a 1 at
# every position of axis "i".
def ones(i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return i_s * 0 + 1
# Check the solution against ones_spec on random instances.
make_test("one", ones, ones_spec)
Correct
def sum_spec(a, out):
    """Reference for ``sum``: accumulate every element of ``a`` into ``out[0]``."""
    for value in a:
        out[0] = out[0] + value
# Solution: with a single operand, contract just sums `a` over axis "i".
# Note that this definition shadows the builtin `sum` for the rest of the notebook.
def sum(a: Int32[NamedArray, "i"]) -> Int32[NamedArray, ""]:
    return contract("i", a)
# Check the solution against sum_spec on random instances.
make_test("sum", sum, sum_spec)
Correct
def outer_spec(a, b, out):
    """Reference for ``outer``: ``out[i][j] = a[i] * b[j]``."""
    for i, row in enumerate(out):
        for j in range(len(out[0])):
            row[j] = a[i] * b[j]
# Solution: `a` lives on axis "i" and `b` on axis "j", so their product is
# expected to broadcast over both named axes (the result is annotated "i j").
def outer(a: Int32[NamedArray, "i"], b : Int32[NamedArray, "j"]) -> Int32[NamedArray, "i j"]:
    return a * b
# Check the solution against outer_spec on random instances.
make_test("outer", outer, outer_spec)
Correct
def diag_spec(a, i1_s, out):
    """Reference for ``diag``: copy the main diagonal of ``a`` into ``out``."""
    for i, row in enumerate(a):
        out[i] = row[i]
# Solution: index both axes of `a` with the same index vector, so element k
# of the result is a[k][k].
def diag(a: Int32[NamedArray, "i1 i2"], i1_s: Int32[NamedArray, "i1"]) -> Int32[NamedArray, "i1"]:
    return a[{"i1": i1_s, "i2": i1_s}]
# Check the solution against diag_spec on random instances.
make_test("diag", diag, diag_spec)
Correct
def eye_spec(i1_s, i2_s, out):
    """Reference for ``eye``: 1 where the row index equals the column index, else 0."""
    for i in i1_s:
        for j in i2_s:
            out[i][j] = 1 if i == j else 0
# Solution: comparing the two index vectors (on axes "i1" and "i2") gives a
# boolean i1 x i2 array that is True exactly on the diagonal.
def eye(i1_s: Int32[NamedArray, "i1"], i2_s : Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i1 i2"]:
    return i1_s == i2_s
# Check the solution against eye_spec on random instances.
make_test("eye", eye, eye_spec)
Correct
def triu_spec(i1_s, i2_s, out):
    """Reference for ``triu``: 1 on and above the diagonal, 0 below it."""
    for i in i1_s:
        for j in i2_s:
            out[i][j] = 1 if i <= j else 0
# Solution: the comparison is True where the row index does not exceed the
# column index, i.e. on and above the diagonal.
def triu(i1_s: Int32[NamedArray, "i1"], i2_s : Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i1 i2"]:
    return i1_s <= i2_s
# Check the solution against triu_spec on random instances.
make_test("triu", triu, triu_spec)
Correct
def cumsum_spec(a, i1_s, i2_s, out):
    """Reference for ``cumsum``: ``out[i]`` is the sum of ``a[0..i]``."""
    running = 0
    for i in range(len(out)):
        running += a[i]
        out[i] = running
# Solution: (i1_s <= i2_s) is an i1 x i2 mask; contracting it with `a` over
# "i1" sums a[i1] for every i1 <= i2, leaving the running total on axis "i2".
def cumsum(a: Int32[NamedArray, "i1"], i1_s : Int32[NamedArray, "i1"], i2_s: Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i2"]:
    return contract("i1", i1_s <= i2_s, a)
# Check the solution against cumsum_spec on random instances.
make_test("cumsum", cumsum, cumsum_spec)
Correct
def diff_spec(a, i_s, out):
    """Reference for ``diff``: first element kept, then successive differences."""
    out[0] = a[0]
    # Start at 1: starting at 0 would overwrite out[0] with a[0] - a[-1]
    # (a circular difference) and make the assignment above a dead store.
    for i in range(1, len(out)):
        out[i] = a[i] - a[i - 1]
# Solution: position 0 keeps a[0]; every other position subtracts its left
# neighbour. The `where` guards position 0, where index -1 would wrap around.
def diff(a: Int32[NamedArray, "i"], i1_s : Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return where(i1_s == 0, a, a - a[{"i": i1_s - 1}])
# Check the solution against diff_spec on random instances.
make_test("diff", diff, diff_spec)
Correct
def stack_spec(a, b, j_s, out):
    """Reference for ``stack``: row 0 of ``out`` is ``a`` and row 1 is ``b``."""
    for col in range(len(out[0])):
        for row, source in ((0, a), (1, b)):
            out[row][col] = source[col]
# Solution: `where` combines j_s (axis "j") with a and b (axis "i"); the row
# with j == 1 takes b and every other row takes a.
def stack(a: Int32[NamedArray, "i"], b: Int32[NamedArray, "i"],
          j_s: Int32[NamedArray, "j"]) -> Int32[NamedArray, "j i"]:
    return where(j_s == 1, b, a)
# init_size pins axis "j" to length 2 so the output has exactly two rows.
make_test("stack", stack, stack_spec, init_size={"j" : 2})
Correct
def roll_spec(a, i_s, out):
    """Reference for ``roll``: shift ``a`` left by one, wrapping the first element to the end."""
    n = len(out)
    for i in range(n):
        source = i + 1 if i + 1 < n else i + 1 - n
        out[i] = a[source]
# Solution: i_s.named_shape["i"] is the axis length, so (i_s + 1) % length
# maps each position to its right neighbour and the last position to 0.
def roll(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return a[{"i": (i_s + 1) % i_s.named_shape["i"]}]
# Check the solution against roll_spec on random instances.
make_test("roll", roll, roll_spec)
Correct
def flip_spec(a, i_s, out):
    """Reference for ``flip``: ``out`` is ``a`` in reverse order."""
    n = len(out)
    for i in range(n):
        out[i] = a[n - i - 1]
# Solution: -i_s - 1 yields the indices -1, -2, ...; this relies on negative
# indices counting from the end of the axis.
def flip(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return a[{"i": -i_s - 1}]
# Check the solution against flip_spec on random instances.
make_test("flip", flip, flip_spec)
Correct
def compress_spec(g, v, i1_s, i2_s, out):
    """Reference for ``compress``: pack the ``v[i]`` with ``g[i] > 1`` to the front, zero-fill the rest."""
    out[:] = [0] * len(out)
    cursor = 0
    for i, gate in enumerate(g):
        if gate > 1:
            out[cursor] = v[i]
            cursor += 1
# Solution (previously a placeholder that returned g and failed the check).
# Let mask = g > 1 on axis "i1". The cumulative sum of mask, minus 1, is the
# output slot of each selected element. v is re-indexed onto "i1", and the
# contraction over "i1" adds v[i] into slot j only where mask[i] holds and
# slot(i) == j; unselected slots therefore stay 0.
# NOTE(review): relies on axes "i1" and "i2" having the same length, which
# make_test provides because sizes are keyed on the first letter of the name.
def compress(g: Int32[NamedArray, "i1"], v: Int32[NamedArray, "i2"], i1_s:Int32[NamedArray, "i1"], i2_s:Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i2"]:
    return contract("i1", g > 1, v[{"i2": i1_s}], cumsum(g > 1, i1_s, i2_s)[{"i2": i1_s}] - 1 == i2_s)
# Check the solution against compress_spec on random instances.
make_test("compress", compress, compress_spec)
Failure
Compute pad_to - eliminate or add 0s to change size of vector.
def pad_to_spec(a, i_s, j_s, out):
    """Reference for ``pad_to``: copy ``a`` into ``out``, truncating or zero-padding to fit."""
    source_len = len(a)
    for i in range(len(out)):
        out[i] = a[i] if i < source_len else 0
# Solution: (j_s == i_s) is a j x i mask pairing each output slot with the
# input position of the same index; contracting over "i" copies a[j] where
# such a position exists and leaves 0 in the remaining slots.
def pad_to(a: Int32[NamedArray, "i"], i_s:Int32[NamedArray, "i"], j_s:Int32[NamedArray, "j"]) -> Int32[NamedArray, "j"]:
    return contract("i", a, j_s == i_s)
# Check the solution against pad_to_spec on random instances.
make_test("pad_to", pad_to, pad_to_spec)
Correct
Compute sequence_mask - pad out to length per batch.
# Didn't do
# def sequence_mask_spec(values, length, out):
# for i in range(len(out)):
# for j in range(len(out[0])):
# if j < length[i]:
# out[i][j] = values[i][j]
# else:
# out[i][j] = 0
# def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
# pass
# def constraint_set_length(d):
# d["length"] = d["length"] % d["values"].shape[1]
# return d
# make_test("sequence_mask",
# sequence_mask, sequence_mask_spec, constraint=constraint_set_length
# )
def bincount_spec(a, i_s, j1_s, j2_s, out):
    """Reference for ``bincount``: ``out[k]`` counts how often ``k`` occurs in ``a``."""
    out[:] = [0] * len(out)
    for bucket in a:
        out[bucket] += 1
# Solution: (j1_s == j2_s) is an identity mask over j1 x j2. Indexing "j1"
# with `a` picks one one-hot row (over "j2") per element of `a`; summing
# those rows over "i" counts the occurrences of each value.
def bincount(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"],
             j1_s: Int32[NamedArray, "j1"], j2_s: Int32[NamedArray, "j2"]) -> Int32[NamedArray, "j2"]:
    return contract("i", (j1_s == j2_s)[{"j1": a}])
def constraint_set_max(d):
    """Wrap the values of ``d["a"]`` into the valid bin range of the output axis "j2"."""
    bins = d["return"].named_shape["j2"]
    d["a"] = d["a"] % bins
    return d
# Check bincount against its spec; the constraint first wraps the random
# values of "a" into the index range of the output axis.
make_test("bincount",
    bincount, bincount_spec, constraint=constraint_set_max
)
Correct
Compute scatter_add - add together values that link to the same location.
def scatter_add_spec(values, link, j1_s, j2_s, out):
    """Reference for ``scatter_add``: add each ``values[j]`` into ``out[link[j]]``."""
    out[:] = [0] * len(out)
    for j, value in enumerate(values):
        out[link[j]] += value
# Solution: same one-hot construction as bincount, but each one-hot row is
# weighted by the matching entry of `values` before summing over "i".
def scatter_add(values: Int32[NamedArray, "i"], link: Int32[NamedArray,"i"],
                j1_s: Int32[NamedArray, "j1"], j2_s: Int32[NamedArray, "j2"]) -> Int32[NamedArray, "j2"]:
    return contract("i", values, (j1_s == j2_s)[{"j1": link}])
def constraint_set_max(d):
    """Wrap the values of ``d["link"]`` into the valid index range of the output axis "j2"."""
    slots = d["return"].named_shape["j2"]
    d["link"] = d["link"] % slots
    return d
# Check scatter_add against its spec; the constraint first wraps the random
# values of "link" into the index range of the output axis.
make_test("scatter_add",
    scatter_add, scatter_add_spec, constraint=constraint_set_max
)
Correct
import inspect
# Solutions to measure, in puzzle order.
fns = (ones, sum, outer, diag, eye, triu, cumsum, diff, stack, roll, flip,
       compress, pad_to, bincount, scatter_add)
for fn in fns:
    # Source lines of the solution with comment-only lines dropped.
    lines = [l for l in inspect.getsource(fn).split("\n") if not l.strip().startswith("#")]
    # NOTE(review): a one-line solution with a one-line signature appears to
    # yield three entries (signature, body, trailing empty string). A signature
    # that spans two lines also takes the first branch, which is presumably why
    # bincount and scatter_add are reported as "more than 1 line".
    if len(lines) > 3:
        print(fn.__name__, len(lines[2]), "(more than 1 line)")
    else:
        # Character count of the single body line, indentation included.
        print(fn.__name__, len(lines[1]))
ones 22 sum 27 outer 16 diag 38 eye 36 triu 36 cumsum 55 diff 33 stack 43 roll 53 flip 31 compress 12 pad_to 52 bincount 52 (more than 1 line) scatter_add 63 (more than 1 line)
%%shell
jupyter nbconvert --to html /content/Tensor_Puzzlers_Penzai.ipynb
[NbConvertApp] WARNING | pattern '/content/Tensor_Puzzlers_Penzai.ipynb' matched no files This application is used to convert notebook files (*.ipynb) to various other formats. WARNING: THE COMMANDLINE INTERFACE MAY CHANGE IN FUTURE RELEASES. Options ======= The options below are convenience aliases to configurable class-options, as listed in the "Equivalent to" description-line of the aliases. To see all configurable class-options for some <cmd>, use: <cmd> --help-all --debug set log level to logging.DEBUG (maximize logging output) Equivalent to: [--Application.log_level=10] --show-config Show the application's configuration (human-readable format) Equivalent to: [--Application.show_config=True] --show-config-json Show the application's configuration (json format) Equivalent to: [--Application.show_config_json=True] --generate-config generate default config file Equivalent to: [--JupyterApp.generate_config=True] -y Answer yes to any questions instead of prompting. Equivalent to: [--JupyterApp.answer_yes=True] --execute Execute the notebook prior to export. Equivalent to: [--ExecutePreprocessor.enabled=True] --allow-errors Continue notebook execution even if one of the cells throws an error and include the error message in the cell output (the default behaviour is to abort conversion). This flag is only relevant if '--execute' was specified, too. Equivalent to: [--ExecutePreprocessor.allow_errors=True] --stdin read a single notebook file from stdin. Write the resulting notebook with default basename 'notebook.*' Equivalent to: [--NbConvertApp.from_stdin=True] --stdout Write notebook output to stdout instead of files. 
Equivalent to: [--NbConvertApp.writer_class=StdoutWriter] --inplace Run nbconvert in place, overwriting the existing notebook (only relevant when converting to notebook format) Equivalent to: [--NbConvertApp.use_output_suffix=False --NbConvertApp.export_format=notebook --FilesWriter.build_directory=] --clear-output Clear output of current file and save in place, overwriting the existing notebook. Equivalent to: [--NbConvertApp.use_output_suffix=False --NbConvertApp.export_format=notebook --FilesWriter.build_directory= --ClearOutputPreprocessor.enabled=True] --no-prompt Exclude input and output prompts from converted document. Equivalent to: [--TemplateExporter.exclude_input_prompt=True --TemplateExporter.exclude_output_prompt=True] --no-input Exclude input cells and output prompts from converted document. This mode is ideal for generating code-free reports. Equivalent to: [--TemplateExporter.exclude_output_prompt=True --TemplateExporter.exclude_input=True --TemplateExporter.exclude_input_prompt=True] --allow-chromium-download Whether to allow downloading chromium if no suitable version is found on the system. Equivalent to: [--WebPDFExporter.allow_chromium_download=True] --disable-chromium-sandbox Disable chromium security sandbox when converting to PDF.. Equivalent to: [--WebPDFExporter.disable_sandbox=True] --show-input Shows code input. This flag is only useful for dejavu users. Equivalent to: [--TemplateExporter.exclude_input=False] --embed-images Embed the images as base64 dataurls in the output. This flag is only useful for the HTML/WebPDF/Slides exports. Equivalent to: [--HTMLExporter.embed_images=True] --sanitize-html Whether the HTML in Markdown cells and cell outputs should be sanitized.. Equivalent to: [--HTMLExporter.sanitize_html=True] --log-level=<Enum> Set the log level by value or name. 
Choices: any of [0, 10, 20, 30, 40, 50, 'DEBUG', 'INFO', 'WARN', 'ERROR', 'CRITICAL'] Default: 30 Equivalent to: [--Application.log_level] --config=<Unicode> Full path of a config file. Default: '' Equivalent to: [--JupyterApp.config_file] --to=<Unicode> The export format to be used, either one of the built-in formats ['asciidoc', 'custom', 'html', 'latex', 'markdown', 'notebook', 'pdf', 'python', 'rst', 'script', 'slides', 'webpdf'] or a dotted object name that represents the import path for an ``Exporter`` class Default: '' Equivalent to: [--NbConvertApp.export_format] --template=<Unicode> Name of the template to use Default: '' Equivalent to: [--TemplateExporter.template_name] --template-file=<Unicode> Name of the template file to use Default: None Equivalent to: [--TemplateExporter.template_file] --theme=<Unicode> Template specific theme(e.g. the name of a JupyterLab CSS theme distributed as prebuilt extension for the lab template) Default: 'light' Equivalent to: [--HTMLExporter.theme] --sanitize_html=<Bool> Whether the HTML in Markdown cells and cell outputs should be sanitized.This should be set to True by nbviewer or similar tools. Default: False Equivalent to: [--HTMLExporter.sanitize_html] --writer=<DottedObjectName> Writer class used to write the results of the conversion Default: 'FilesWriter' Equivalent to: [--NbConvertApp.writer_class] --post=<DottedOrNone> PostProcessor class used to write the results of the conversion Default: '' Equivalent to: [--NbConvertApp.postprocessor_class] --output=<Unicode> overwrite base name use for output files. can only be used when converting one notebook at a time. Default: '' Equivalent to: [--NbConvertApp.output_base] --output-dir=<Unicode> Directory to write output(s) to. Defaults to output to the directory of each notebook. To recover previous default behaviour (outputting to the current working directory) use . as the flag value. 
Default: '' Equivalent to: [--FilesWriter.build_directory] --reveal-prefix=<Unicode> The URL prefix for reveal.js (version 3.x). This defaults to the reveal CDN, but can be any url pointing to a copy of reveal.js. For speaker notes to work, this must be a relative path to a local copy of reveal.js: e.g., "reveal.js". If a relative path is given, it must be a subdirectory of the current directory (from which the server is run). See the usage documentation (https://nbconvert.readthedocs.io/en/latest/usage.html#reveal-js-html-slideshow) for more details. Default: '' Equivalent to: [--SlidesExporter.reveal_url_prefix] --nbformat=<Enum> The nbformat version to write. Use this to downgrade notebooks. Choices: any of [1, 2, 3, 4] Default: 4 Equivalent to: [--NotebookExporter.nbformat_version] Examples -------- The simplest way to use nbconvert is > jupyter nbconvert mynotebook.ipynb --to html Options include ['asciidoc', 'custom', 'html', 'latex', 'markdown', 'notebook', 'pdf', 'python', 'rst', 'script', 'slides', 'webpdf']. > jupyter nbconvert --to latex mynotebook.ipynb Both HTML and LaTeX support multiple output templates. LaTeX includes 'base', 'article' and 'report'. HTML includes 'basic', 'lab' and 'classic'. You can specify the flavor of the format used. 
> jupyter nbconvert --to html --template lab mynotebook.ipynb You can also pipe the output to stdout, rather than a file > jupyter nbconvert mynotebook.ipynb --stdout PDF is generated via latex > jupyter nbconvert mynotebook.ipynb --to pdf You can get (and serve) a Reveal.js-powered slideshow > jupyter nbconvert myslides.ipynb --to slides --post serve Multiple notebooks can be given at the command line in a couple of different ways: > jupyter nbconvert notebook*.ipynb > jupyter nbconvert notebook1.ipynb notebook2.ipynb or you can specify the notebooks list in a config file, containing:: c.NbConvertApp.notebooks = ["my_notebook.ipynb"] > jupyter nbconvert --config mycfg.py To see all available configurables, use `--help-all`.
--------------------------------------------------------------------------- CalledProcessError Traceback (most recent call last) <ipython-input-39-6787be2b50f1> in <cell line: 1>() ----> 1 get_ipython().run_cell_magic('shell', '', 'jupyter nbconvert --to html /content/Tensor_Puzzlers_Penzai.ipynb\n') /usr/local/lib/python3.10/dist-packages/google/colab/_shell.py in run_cell_magic(self, magic_name, line, cell) 332 if line and not cell: 333 cell = ' ' --> 334 return super().run_cell_magic(magic_name, line, cell) 335 336 /usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py in run_cell_magic(self, magic_name, line, cell) 2471 with self.builtin_trap: 2472 args = (magic_arg_s, cell) -> 2473 result = fn(*args, **kwargs) 2474 return result 2475 /usr/local/lib/python3.10/dist-packages/google/colab/_system_commands.py in _shell_cell_magic(args, cmd) 110 result = _run_command(cmd, clear_streamed_output=False) 111 if not parsed_args.ignore_errors: --> 112 result.check_returncode() 113 return result 114 /usr/local/lib/python3.10/dist-packages/google/colab/_system_commands.py in check_returncode(self) 135 def check_returncode(self): 136 if self.returncode: --> 137 raise subprocess.CalledProcessError( 138 returncode=self.returncode, cmd=self.args, output=self.output 139 ) CalledProcessError: Command 'jupyter nbconvert --to html /content/Tensor_Puzzlers_Penzai.ipynb ' returned non-zero exit status 255.