541 lines
17 KiB
Python
541 lines
17 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# Wadler-Lindig pretty printer.
|
|
#
|
|
# References:
|
|
# Wadler, P., 1998. A prettier printer. Journal of Functional Programming,
|
|
# pp.223-244.
|
|
#
|
|
# Lindig, C. 2000. Strictly Pretty.
|
|
# https://lindig.github.io/papers/strictly-pretty-2000.pdf
|
|
#
|
|
# Hafiz, A. 2021. Strictly Annotated: A Pretty-Printer With Support for
|
|
# Annotations. https://ayazhafiz.com/articles/21/strictly-annotated
|
|
#
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Sequence
|
|
import enum
|
|
from functools import partial
|
|
import sys
|
|
from typing import Any, NamedTuple, TYPE_CHECKING
|
|
|
|
from jax._src import config
|
|
from jax._src import util
|
|
from jax._src.lib import _pretty_printer as _pretty_printer
|
|
|
|
|
|
# Config flag that enables or disables colored jaxpr pretty-printing.
# `Doc.format` consults it only when the caller leaves `use_color` unset.
_PPRINT_USE_COLOR = config.bool_state(
    'jax_pprint_use_color',
    True,
    help='Enable jaxpr pretty-printing with colorful syntax highlighting.'
)
|
|
|
|
def _can_use_color() -> bool:
|
|
try:
|
|
# Check if we're in IPython or Colab
|
|
ipython = get_ipython() # type: ignore[name-defined]
|
|
shell = ipython.__class__.__name__
|
|
if shell == "ZMQInteractiveShell":
|
|
# Jupyter Notebook
|
|
return True
|
|
elif "colab" in str(ipython.__class__):
|
|
# Google Colab (external or internal)
|
|
return True
|
|
except NameError:
|
|
pass
|
|
# Otherwise check if we're in a terminal
|
|
return hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
|
|
|
|
# Computed once when this module is imported; `Doc.format` combines it with
# the `jax_pprint_use_color` flag when `use_color` is not given.
CAN_USE_COLOR = _can_use_color()
|
|
|
|
# TODO(phawkins): remove this condition after the jaxlib 0.6.3 release.
if TYPE_CHECKING or _pretty_printer is None:
  # colorama is optional: when it cannot be imported the name is bound to None
  # and `_update_color` emits no escape sequences.
  try:
    import colorama  # pytype: disable=import-error
  except ImportError:
    colorama = None
|
|
|
|
class Doc(util.StrictABC):
  """Base class for pretty-printer document nodes."""
  __slots__ = ()

  def format(
      self,
      width: int = 80,
      *,
      use_color: bool | None = None,
      annotation_prefix: str = " # ",
      source_map: list[list[tuple[int, int, Any]]] | None = None
  ) -> str:
    """Renders this document as a string.

    Args:
      width: the line width that groups are checked against when deciding
        whether they can be printed flat.
      use_color: whether to emit color escape codes. When ``None``, color is
        used only if both ``CAN_USE_COLOR`` and the ``jax_pprint_use_color``
        flag are true.
      annotation_prefix: text placed between a line and its annotation.
      source_map: optional output list. For each line in the output, a list of
        (start column, end column, source) tuples is appended. Each tuple
        associates a region of output text with a source.

    Returns:
      The formatted text.
    """
    color_enabled = use_color
    if color_enabled is None:
      color_enabled = CAN_USE_COLOR and _PPRINT_USE_COLOR.value
    return _format(
        self,
        width,
        use_color=color_enabled,
        annotation_prefix=annotation_prefix,
        source_map=source_map,
    )

  def __str__(self):
    # Printing a document formats it with the default settings.
    return self.format()

  def __add__(self, other: Doc) -> Doc:
    """Returns the concatenation of this document followed by `other`."""
    return concat([self, other])

  def num_annotations(self) -> int:
    """Returns the number of annotated texts in this document.

    Subclasses must override this method.
    """
    raise NotImplementedError()
|
|
|
|
class _NilDoc(Doc):
  """The empty document: prints nothing and carries no annotations."""

  def __repr__(self):
    return "nil"

  def num_annotations(self) -> int:
    return 0


# Shared instance returned by `nil()`.
_nil = _NilDoc()
|
|
|
|
class _TextDoc(Doc):
  """A piece of literal text with an optional annotation."""
  __slots__ = ("text", "annotation")
  text: str
  annotation: str | None

  def __init__(self, text: str, annotation: str | None = None):
    assert isinstance(text, str), text
    assert annotation is None or isinstance(annotation, str), annotation
    self.text = text
    self.annotation = annotation

  def __repr__(self):
    if self.annotation is None:
      return f'text("{self.text}")'
    return f'text("{self.text}", annotation="{self.annotation}")'

  def num_annotations(self) -> int:
    # A text node contributes one annotation exactly when it carries one.
    return 0 if self.annotation is None else 1
|
|
|
|
class _ConcatDoc(Doc):
  """A sequence of child documents printed one after another."""
  __slots__ = ("children", "_num_annotations")
  children: list[Doc]
  _num_annotations: int

  def __init__(self, children: Sequence[Doc]):
    """Stores a copy of `children` and caches their total annotation count.

    Args:
      children: the documents to concatenate, in output order.
    """
    self.children = list(children)
    # Count from the stored list rather than the argument, so that a one-shot
    # iterable (already consumed by `list(...)` above) is still counted.
    self._num_annotations = sum(
        child.num_annotations() for child in self.children)

  def __repr__(self):
    return f"concat({self.children})"

  def num_annotations(self) -> int:
    """Returns the annotation count cached at construction time."""
    return self._num_annotations
|
|
|
|
class _BreakDoc(Doc):
  """A potential line break; `text` is what is printed when it stays flat."""
  __slots__ = ("text",)
  text: str

  def __init__(self, text: str):
    assert isinstance(text, str), text
    self.text = text

  def __repr__(self):
    return f"break({self.text})"

  def num_annotations(self) -> int:
    # Breaks never carry annotations.
    return 0
|
|
|
|
class _GroupDoc(Doc):
  """Marks `child` as a unit whose breaks are printed flat or broken together."""
  __slots__ = ("child",)
  child: Doc

  def __init__(self, child: Doc):
    assert isinstance(child, Doc), child
    self.child = child

  def __repr__(self):
    return f"group({self.child})"

  def num_annotations(self) -> int:
    # A group adds no annotations of its own.
    return self.child.num_annotations()
|
|
|
|
class _NestDoc(Doc):
  """Increases the indentation used inside `child` by `n` columns."""
  __slots__ = ("n", "child",)
  n: int
  child: Doc

  def __init__(self, n: int, child: Doc):
    assert isinstance(child, Doc), child
    self.n = n
    self.child = child

  def __repr__(self):
    # Format the two fields separately. Interpolating the tuple
    # `(self.n, self.child)` would print as `nest((n, child))`.
    return f"nest({self.n}, {self.child})"

  def num_annotations(self) -> int:
    # Nesting adds no annotations of its own.
    return self.child.num_annotations()
|
|
|
|
# Sentinel meaning "no source region is currently active".
_NO_SOURCE = object()


class _SourceMapDoc(Doc):
  """Associates the output produced by `child` with a `source` object."""
  __slots__ = ("child", "source")
  child: Doc
  source: Any

  def __init__(self, child: Doc, source: Any):
    assert isinstance(child, Doc), child
    self.child = child
    self.source = source

  def __repr__(self):
    return f"source({self.child}, {self.source})"

  def num_annotations(self) -> int:
    # Source mapping adds no annotations of its own.
    return self.child.num_annotations()
|
|
|
|
# Colors and intensities accepted by `color()`. When rendering, member names
# are looked up as attributes of `colorama.Fore`, `colorama.Back` and
# `colorama.Style` respectively.
Color = enum.Enum("Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE",
                            "MAGENTA", "CYAN", "WHITE", "RESET"])
Intensity = enum.Enum("Intensity", ["DIM", "NORMAL", "BRIGHT"])
|
|
|
|
class _ColorDoc(Doc):
  """Overrides foreground, background and/or intensity for `child`.

  A field left as ``None`` keeps whatever value is already in effect.
  """
  __slots__ = ("foreground", "background", "intensity", "child")
  foreground: Color | None
  background: Color | None
  intensity: Intensity | None
  child: Doc

  def __init__(self, child: Doc, *, foreground: Color | None = None,
               background: Color | None = None,
               intensity: Intensity | None = None):
    assert isinstance(child, Doc), child
    self.child = child
    self.foreground = foreground
    self.background = background
    self.intensity = intensity

  def __repr__(self):
    # Every other Doc subclass defines a structural repr; without one here a
    # document tree containing a color node prints an opaque object address.
    return (f"color({self.child}, foreground={self.foreground}, "
            f"background={self.background}, intensity={self.intensity})")

  def num_annotations(self) -> int:
    # Coloring adds no annotations of its own.
    return self.child.num_annotations()
|
|
|
|
# How breaks are rendered while laying out a document: in FLAT mode a break
# prints its text; in BREAK mode it ends the line and starts a new one at the
# current indentation.
_BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"])
|
|
|
|
|
|
# In Lindig's paper fits() and format() are defined recursively. This is a
|
|
# non-recursive formulation using an explicit stack, necessary because Python
|
|
# doesn't have a tail recursion optimization.
|
|
|
|
def _fits(doc: Doc, width: int) -> bool:
|
|
agenda = [doc]
|
|
while width >= 0 and len(agenda) > 0:
|
|
doc = agenda.pop()
|
|
if isinstance(doc, _NilDoc):
|
|
pass
|
|
elif isinstance(doc, _TextDoc):
|
|
width -= len(doc.text)
|
|
elif isinstance(doc, _ConcatDoc):
|
|
agenda.extend(reversed(doc.children))
|
|
elif isinstance(doc, _BreakDoc):
|
|
width -= len(doc.text)
|
|
elif isinstance(doc, (_NestDoc, _GroupDoc, _ColorDoc, _SourceMapDoc)):
|
|
agenda.append(doc.child)
|
|
else:
|
|
raise ValueError("Invalid document ", doc)
|
|
|
|
return width >= 0
|
|
|
|
|
|
# Annotation layout: A flat group is sparse if there are no breaks between
|
|
# annotations.
|
|
def _sparse(doc: Doc) -> bool:
|
|
agenda = [doc]
|
|
if doc.num_annotations() == 0:
|
|
return True
|
|
num_annotations = 0
|
|
seen_break = False
|
|
while len(agenda) > 0:
|
|
doc = agenda.pop()
|
|
if isinstance(doc, _NilDoc):
|
|
pass
|
|
elif isinstance(doc, _TextDoc):
|
|
if doc.annotation is not None:
|
|
if num_annotations >= 1 and seen_break:
|
|
return False
|
|
num_annotations += 1
|
|
elif isinstance(doc, _ConcatDoc):
|
|
agenda.extend(reversed(doc.children))
|
|
elif isinstance(doc, _BreakDoc):
|
|
seen_break = True
|
|
elif isinstance(doc, _NestDoc):
|
|
agenda.append(doc.child)
|
|
elif isinstance(doc, _GroupDoc):
|
|
agenda.append(doc.child)
|
|
elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc):
|
|
agenda.append(doc.child)
|
|
else:
|
|
raise ValueError("Invalid document ", doc)
|
|
|
|
return True
|
|
|
|
class _ColorState(NamedTuple):
  """A complete foreground/background/intensity setting for output text."""
  foreground: Color
  background: Color
  intensity: Intensity
|
|
|
|
class _State(NamedTuple):
  """One pending work item on the agenda stack used by `_format`."""
  indent: int  # Indentation used for lines started by breaks in this item.
  mode: _BreakMode  # Whether breaks in this item print flat or as newlines.
  doc: Doc  # The document still to be laid out.
  color: _ColorState  # Colors to apply to text in this item.
  source_map: Any  # Source object for this item, or _NO_SOURCE.
|
|
|
|
class _Line(NamedTuple):
  """A completed output line, before annotations are aligned and appended."""
  text: str  # Line contents, including any color escape sequences.
  width: int  # Width counted from indentation and document text only.
  annotations: list[str]  # Annotations of the texts printed on this line.
|
|
|
|
|
|
def _update_color(use_color: bool, state: _ColorState, update: _ColorState
|
|
) -> tuple[_ColorState, str]:
|
|
if not use_color or colorama is None:
|
|
return update, ""
|
|
color_str = ""
|
|
if state.foreground != update.foreground:
|
|
color_str += getattr(colorama.Fore, str(update.foreground.name))
|
|
if state.background != update.background:
|
|
color_str += getattr(colorama.Back, str(update.background.name))
|
|
if state.intensity != update.intensity:
|
|
color_str += colorama.Style.NORMAL # pytype: disable=unsupported-operands
|
|
color_str += getattr(colorama.Style, str(update.intensity.name))
|
|
return update, color_str
|
|
|
|
|
|
def _align_annotations(lines: list[_Line], annotation_prefix: str) -> list[str]:
|
|
# TODO: Hafiz also implements a local alignment mode, where groups of lines
|
|
# with annotations are aligned together.
|
|
maxlen = max(l.width for l in lines)
|
|
out = []
|
|
for l in lines:
|
|
if len(l.annotations) == 0:
|
|
out.append(l.text)
|
|
else:
|
|
out.append(f"{l.text}{' ' * (maxlen - l.width)}"
|
|
f"{annotation_prefix}{l.annotations[0]}")
|
|
for a in l.annotations[1:]:
|
|
out.append(f"{' ' * maxlen}{annotation_prefix}{a}")
|
|
return out
|
|
|
|
|
|
|
|
def _format(
    doc: Doc, width: int, *, use_color: bool, annotation_prefix: str,
    source_map: list[list[tuple[int, int, Any]]] | None
) -> str:
  """Lays out `doc` and returns the rendered text.

  The document is processed with an explicit stack of `_State` items rather
  than by recursion.

  Args:
    doc: the document to print.
    width: the line width that groups are checked against.
    use_color: whether to emit color escape codes.
    annotation_prefix: text placed between a line and its annotation.
    source_map: if not None, one list of (start column, end column, source)
      tuples is appended to it for every output line.

  Returns:
    The formatted text, followed by any escape codes needed to restore the
    default colors.

  Raises:
    ValueError: if an unknown document type is encountered.
  """
  default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL)
  annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM)
  color_state = default_colors

  finished = []  # Completed lines.
  line_text = ""  # Text of the line being built, including escape codes.
  line_annotations = []  # Annotations collected for the line being built.
  column = 0  # Width of the current line, not counting escape codes.

  source = _NO_SOURCE  # The currently active source region.
  source_start = 0  # The column at which the current source region starts.
  line_source_map = []  # Source maps for the current line of text.

  agenda = [_State(0, _BreakMode.BREAK, doc, default_colors, source)]
  while agenda:
    indent, mode, node, color, node_source = agenda.pop()

    # Close the active source region whenever the source changes.
    if source_map is not None and node_source != source:
      pos = len(line_text)
      if source is not _NO_SOURCE and source_start != pos:
        line_source_map.append((source_start, pos, source))
      source, source_start = node_source, pos

    if isinstance(node, _NilDoc):
      continue

    if isinstance(node, _TextDoc):
      color_state, escape = _update_color(use_color, color_state, color)
      line_text += escape + node.text
      if node.annotation is not None:
        line_annotations.append(node.annotation)
      column += len(node.text)
    elif isinstance(node, _ConcatDoc):
      # Push in reverse so children are popped in document order.
      agenda.extend(_State(indent, mode, child, color, source)
                    for child in reversed(node.children))
    elif isinstance(node, _BreakDoc):
      if mode != _BreakMode.BREAK:
        # Flat mode: the break prints as its text.
        color_state, escape = _update_color(use_color, color_state, color)
        line_text += escape + node.text
        column += len(node.text)
        continue
      # Break mode: finish the current line and start a new, indented one.
      if line_annotations:
        color_state, escape = _update_color(use_color, color_state,
                                            annotation_colors)
        line_text += escape
      finished.append(_Line(line_text, column, line_annotations))
      if source_map is not None:
        pos = len(line_text)
        if source is not _NO_SOURCE and source_start != pos:
          line_source_map.append((source_start, pos, source))
        source_map.append(line_source_map)
        line_source_map = []
        source_start = indent
      line_text = " " * indent
      line_annotations = []
      column = indent
    elif isinstance(node, _NestDoc):
      agenda.append(_State(indent + node.n, mode, node.child, color, source))
    elif isinstance(node, _GroupDoc):
      # In Lindig's paper, _fits is passed the remainder of the document.
      # I'm pretty sure that's a bug and we care only if the current group fits!
      print_flat = _fits(node, width - column) and _sparse(node)
      group_mode = _BreakMode.FLAT if print_flat else _BreakMode.BREAK
      agenda.append(_State(indent, group_mode, node.child, color, source))
    elif isinstance(node, _ColorDoc):
      # Unset fields inherit the colors already in effect.
      merged = _ColorState(node.foreground or color.foreground,
                           node.background or color.background,
                           node.intensity or color.intensity)
      agenda.append(_State(indent, mode, node.child, merged, source))
    elif isinstance(node, _SourceMapDoc):
      agenda.append(_State(indent, mode, node.child, color, node.source))
    else:
      raise ValueError("Invalid document ", node)

  # Finish the last line.
  if line_annotations:
    color_state, escape = _update_color(use_color, color_state,
                                        annotation_colors)
    line_text += escape
  if source_map is not None:
    pos = len(line_text)
    if source is not _NO_SOURCE and source_start != pos:
      line_source_map.append((source_start, pos, source))
    source_map.append(line_source_map)
  finished.append(_Line(line_text, column, line_annotations))

  rendered = "\n".join(_align_annotations(finished, annotation_prefix))
  # Restore the default colors at the end of the output.
  _, escape = _update_color(use_color, color_state, default_colors)
  return rendered + escape
|
|
|
|
|
|
|
|
|
|
# Public API.
|
|
|
|
def nil() -> Doc:
  """An empty document.

  Returns:
    The shared `_nil` instance; no new object is created per call.
  """
  return _nil
|
|
|
|
def text(s: str, annotation: str | None = None) -> Doc:
  """Literal text.

  Args:
    s: the text to print.
    annotation: optional annotation, printed after `annotation_prefix` at the
      end of the line that contains the text.

  Returns:
    A document that prints `s`.
  """
  return _TextDoc(s, annotation)
|
|
|
|
def concat(docs: Sequence[Doc]) -> Doc:
  """Concatenation of documents.

  Args:
    docs: the documents to concatenate, in output order.

  Returns:
    The sole element itself when `docs` has exactly one document; otherwise
    a new concatenation node holding all of them.
  """
  children = list(docs)
  # A single document needs no wrapper node.
  return children[0] if len(children) == 1 else _ConcatDoc(children)
|
|
|
|
def brk(text: str = " ") -> Doc:
  """A break.

  Prints either as a newline or as `text`, depending on the enclosing group.

  Args:
    text: what the break prints when it is not rendered as a newline.

  Returns:
    A break document.
  """
  return _BreakDoc(text)
|
|
|
|
def group(doc: Doc) -> Doc:
  """Layout alternative groups.

  Prints the group with its breaks as their text (typically spaces) if the
  entire group would fit on the line when printed that way. Otherwise, breaks
  inside the group are printed as newlines.

  Args:
    doc: the document to treat as a group.

  Returns:
    A group document wrapping `doc`.
  """
  return _GroupDoc(doc)
|
|
|
|
def nest(n: int, doc: Doc) -> Doc:
  """Increases the indentation level by `n`.

  Args:
    n: number of columns added to the indentation used inside `doc`.
    doc: the document to indent.

  Returns:
    A nesting document wrapping `doc`.
  """
  return _NestDoc(n, doc)
|
|
|
|
|
|
def color(doc: Doc, *, foreground: Color | None = None,
          background: Color | None = None,
          intensity: Intensity | None = None) -> Doc:
  """ANSI colors.

  Overrides the foreground/background/intensity of the text for the child doc.
  Requires use_color=True to be set when printing and the `colorama` package
  to be installed; otherwise does nothing.

  Args:
    doc: the document to color.
    foreground: foreground color, or None to keep the one in effect.
    background: background color, or None to keep the one in effect.
    intensity: text intensity, or None to keep the one in effect.

  Returns:
    A color document wrapping `doc`.
  """
  return _ColorDoc(doc, foreground=foreground, background=background,
                   intensity=intensity)
|
|
|
|
|
|
def source_map(doc: Doc, source: Any) -> Doc:
  """Source mapping.

  A source map associates a region of the pretty-printer's text output with a
  source location that produced it. For the purposes of the pretty printer a
  ``source`` may be any object: we require only that we can compare sources for
  equality. A text region to source object mapping can be populated as a side
  output of the ``format`` method.

  Args:
    doc: the document whose output should be attributed to `source`.
    source: the object to record for the text produced by `doc`.

  Returns:
    A source-map document wrapping `doc`.
  """
  return _SourceMapDoc(doc, source)
|
|
|
|
else:
|
|
# The `_pretty_printer` extension module is available: re-export its types
# under the same names that the pure-Python implementation defines.
Color = _pretty_printer.Color
Intensity = _pretty_printer.Intensity
Doc = _pretty_printer.Doc
|
|
def _format(
|
|
self, width: int = 80, *, use_color: bool | None = None,
|
|
annotation_prefix: str = " # ",
|
|
source_map: list[list[tuple[int, int, Any]]] | None = None
|
|
) -> str:
|
|
"""
|
|
Formats a pretty-printer document as a string.
|
|
|
|
Args:
|
|
source_map: for each line in the output, contains a list of
|
|
(start column, end column, source) tuples. Each tuple associates a
|
|
region of output text with a source.
|
|
"""
|
|
if use_color is None:
|
|
use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value
|
|
return self._format(
|
|
width, use_color=use_color, annotation_prefix=annotation_prefix,
|
|
source_map=source_map)
|
|
# Attach the default-resolving wrapper as the public `format` method, and make
# `str(doc)` format with the default settings.
Doc.format = _format
Doc.__str__ = lambda self: self.format()
# Re-export the document constructors from the extension module.
nil = _pretty_printer.nil
text = _pretty_printer.text
concat = _pretty_printer.concat
brk = _pretty_printer.brk
group = _pretty_printer.group
nest = _pretty_printer.nest
color = _pretty_printer.color
source_map = _pretty_printer.source_map
|
|
|
|
|
|
# Preset color helpers: each takes a document and wraps it with `color()`
# using the fixed intensity and foreground below.
type_annotation = partial(color, intensity=Intensity.NORMAL,
                          foreground=Color.MAGENTA)
keyword = partial(color, intensity=Intensity.BRIGHT, foreground=Color.BLUE)
|
|
|
|
|
|
def join(sep: Doc, docs: Sequence[Doc]) -> Doc:
  """Concatenates `docs`, separated by `sep`.

  Args:
    sep: the separator placed between consecutive documents.
    docs: the documents to join.

  Returns:
    The empty document when `docs` is empty, the sole element when it has
    exactly one, and otherwise the concatenation of the documents interleaved
    with `sep`.
  """
  items = list(docs)
  if not items:
    return nil()
  if len(items) == 1:
    return items[0]
  interleaved = [items[0]]
  for item in items[1:]:
    interleaved += (sep, item)
  return concat(interleaved)
|