2025-08-11 12:24:21 +08:00

541 lines
17 KiB
Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Wadler-Lindig pretty printer.
#
# References:
# Wadler, P., 1998. A prettier printer. Journal of Functional Programming,
# pp.223-244.
#
# Lindig, C. 2000. Strictly Pretty.
# https://lindig.github.io/papers/strictly-pretty-2000.pdf
#
# Hafiz, A. 2021. Strictly Annotated: A Pretty-Printer With Support for
# Annotations. https://ayazhafiz.com/articles/21/strictly-annotated
#
from __future__ import annotations
from collections.abc import Sequence
import enum
from functools import partial
import sys
from typing import Any, NamedTuple, TYPE_CHECKING
from jax._src import config
from jax._src import util
from jax._src.lib import _pretty_printer as _pretty_printer
# Boolean config state registered under the name 'jax_pprint_use_color' (the
# second argument, True, is presumably its default -- confirm against
# config.bool_state). Its `.value` is read when formatting a document without
# an explicit `use_color` argument.
_PPRINT_USE_COLOR = config.bool_state(
    'jax_pprint_use_color',
    True,
    help='Enable jaxpr pretty-printing with colorful syntax highlighting.'
)
def _can_use_color() -> bool:
try:
# Check if we're in IPython or Colab
ipython = get_ipython() # type: ignore[name-defined]
shell = ipython.__class__.__name__
if shell == "ZMQInteractiveShell":
# Jupyter Notebook
return True
elif "colab" in str(ipython.__class__):
# Google Colab (external or internal)
return True
except NameError:
pass
# Otherwise check if we're in a terminal
return hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
# Computed once, when this module is imported. Used together with the
# `jax_pprint_use_color` config value to choose the default for `use_color`.
CAN_USE_COLOR = _can_use_color()
# TODO(phawkins): remove this condition after the jaxlib 0.6.3 release.
if TYPE_CHECKING or _pretty_printer is None:
try:
import colorama # pytype: disable=import-error
except ImportError:
colorama = None
class Doc(util.StrictABC):
  """Base class for pretty-printer document nodes."""
  __slots__ = ()

  def num_annotations(self) -> int:
    """Returns how many annotations this document contains.

    The base implementation raises NotImplementedError; subclasses override
    it.
    """
    raise NotImplementedError()

  def format(
      self, width: int = 80, *, use_color: bool | None = None,
      annotation_prefix: str = " # ",
      source_map: list[list[tuple[int, int, Any]]] | None = None
  ) -> str:
    """Renders this document to a string.

    Args:
      width: line width the layout tries to stay within.
      use_color: whether to emit color codes. When None, color is used only
        if both CAN_USE_COLOR and the `jax_pprint_use_color` config value are
        true.
      annotation_prefix: text placed in front of each annotation.
      source_map: optional output argument. For each line of output, a list
        of (start column, end column, source) tuples is appended; each tuple
        associates a region of output text with a source.

    Returns:
      The formatted text.
    """
    colorize = use_color
    if colorize is None:
      colorize = CAN_USE_COLOR and _PPRINT_USE_COLOR.value
    return _format(
        self,
        width,
        use_color=colorize,
        annotation_prefix=annotation_prefix,
        source_map=source_map,
    )

  def __str__(self):
    # Plain str() uses every formatting default.
    return self.format()

  def __add__(self, other: Doc) -> Doc:
    # `a + b` is shorthand for concatenating the two documents.
    return concat([self, other])
class _NilDoc(Doc):
  """The empty document."""

  def num_annotations(self) -> int:
    # An empty document never carries annotations.
    return 0

  def __repr__(self):
    return "nil"


# Module-level instance of the empty document.
_nil = _NilDoc()
class _TextDoc(Doc):
  """A piece of literal text with an optional annotation string."""
  __slots__ = ("text", "annotation")
  text: str
  annotation: str | None

  def __init__(self, text: str, annotation: str | None = None):
    # Both fields must be strings (the annotation may also be None).
    assert isinstance(text, str), text
    assert annotation is None or isinstance(annotation, str), annotation
    self.annotation = annotation
    self.text = text

  def num_annotations(self) -> int:
    # A text node contributes one annotation at most.
    return 0 if self.annotation is None else 1

  def __repr__(self):
    if self.annotation is None:
      return f"text(\"{self.text}\")"
    return f"text(\"{self.text}\", annotation=\"{self.annotation}\")"
class _ConcatDoc(Doc):
  """A sequence of child documents printed one after another.

  The total number of annotations in the children is computed once at
  construction time and cached.
  """
  __slots__ = ("children", "_num_annotations")
  children: list[Doc]
  _num_annotations: int

  def __init__(self, children: Sequence[Doc]):
    self.children = list(children)
    # Count over the materialized list rather than the argument: if a caller
    # passes a one-shot iterable, `list(children)` has already consumed it
    # and iterating it again would silently yield a count of zero.
    self._num_annotations = sum(
        child.num_annotations() for child in self.children)

  def __repr__(self):
    return f"concat({self.children})"

  def num_annotations(self) -> int:
    """Returns the cached annotation count of all children."""
    return self._num_annotations
class _BreakDoc(Doc):
  """A break, holding the text it stands for."""
  __slots__ = ("text",)
  text: str

  def __init__(self, text: str):
    assert isinstance(text, str), text
    self.text = text

  def num_annotations(self) -> int:
    # Breaks never carry annotations.
    return 0

  def __repr__(self):
    return f"break({self.text})"
class _GroupDoc(Doc):
  """Wraps a child document as a layout group."""
  __slots__ = ("child",)
  child: Doc

  def __init__(self, child: Doc):
    assert isinstance(child, Doc), child
    self.child = child

  def num_annotations(self) -> int:
    # A group has exactly the annotations of its child.
    return self.child.num_annotations()

  def __repr__(self):
    return f"group({self.child})"
class _NestDoc(Doc):
  """Wraps a child document together with an indentation amount `n`."""
  __slots__ = ("n", "child",)
  n: int
  child: Doc

  def __init__(self, n: int, child: Doc):
    assert isinstance(child, Doc), child
    self.n = n
    self.child = child

  def __repr__(self):
    # Format the two fields separately. Writing `{self.n, self.child}` inside
    # the f-string builds a tuple and renders as `nest((n, child))`.
    return f"nest({self.n}, {self.child})"

  def num_annotations(self) -> int:
    """Returns the annotation count of the child."""
    return self.child.num_annotations()
# Unique sentinel meaning "no source is active". It is compared by identity,
# so it cannot collide with any source object supplied by a caller.
_NO_SOURCE = object()
class _SourceMapDoc(Doc):
  """Pairs a child document with an arbitrary `source` object."""
  __slots__ = ("child", "source")
  child: Doc
  source: Any

  def __init__(self, child: Doc, source: Any):
    assert isinstance(child, Doc), child
    self.source = source
    self.child = child

  def num_annotations(self) -> int:
    # Source tagging does not add or remove annotations.
    return self.child.num_annotations()

  def __repr__(self):
    return f"source({self.child}, {self.source})"
class Color(enum.Enum):
  """Named colors; member names are looked up on colorama's Fore/Back."""
  BLACK = enum.auto()
  RED = enum.auto()
  GREEN = enum.auto()
  YELLOW = enum.auto()
  BLUE = enum.auto()
  MAGENTA = enum.auto()
  CYAN = enum.auto()
  WHITE = enum.auto()
  RESET = enum.auto()


class Intensity(enum.Enum):
  """Text intensities; member names are looked up on colorama's Style."""
  DIM = enum.auto()
  NORMAL = enum.auto()
  BRIGHT = enum.auto()
class _ColorDoc(Doc):
  """Wraps a child document with optional color overrides.

  Each of `foreground`, `background` and `intensity` may be None, meaning
  "no override for this attribute".
  """
  __slots__ = ("foreground", "background", "intensity", "child")
  foreground: Color | None
  background: Color | None
  intensity: Intensity | None
  child: Doc

  def __init__(self, child: Doc, *, foreground: Color | None = None,
               background: Color | None = None,
               intensity: Intensity | None = None):
    assert isinstance(child, Doc), child
    self.child = child
    self.foreground = foreground
    self.background = background
    self.intensity = intensity

  def __repr__(self):
    # Added for parity with the other document nodes, so that the repr of an
    # enclosing concat/group/nest shows the color node's contents instead of
    # an opaque object address.
    return (f"color({self.child}, foreground={self.foreground}, "
            f"background={self.background}, intensity={self.intensity})")

  def num_annotations(self) -> int:
    """Returns the annotation count of the child."""
    return self.child.num_annotations()
_BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"])
# In Lindig's paper fits() and format() are defined recursively. This is a
# non-recursive formulation using an explicit stack, necessary because Python
# doesn't have a tail recursion optimization.
def _fits(doc: Doc, width: int) -> bool:
agenda = [doc]
while width >= 0 and len(agenda) > 0:
doc = agenda.pop()
if isinstance(doc, _NilDoc):
pass
elif isinstance(doc, _TextDoc):
width -= len(doc.text)
elif isinstance(doc, _ConcatDoc):
agenda.extend(reversed(doc.children))
elif isinstance(doc, _BreakDoc):
width -= len(doc.text)
elif isinstance(doc, (_NestDoc, _GroupDoc, _ColorDoc, _SourceMapDoc)):
agenda.append(doc.child)
else:
raise ValueError("Invalid document ", doc)
return width >= 0
# Annotation layout: A flat group is sparse if there are no breaks between
# annotations.
def _sparse(doc: Doc) -> bool:
agenda = [doc]
if doc.num_annotations() == 0:
return True
num_annotations = 0
seen_break = False
while len(agenda) > 0:
doc = agenda.pop()
if isinstance(doc, _NilDoc):
pass
elif isinstance(doc, _TextDoc):
if doc.annotation is not None:
if num_annotations >= 1 and seen_break:
return False
num_annotations += 1
elif isinstance(doc, _ConcatDoc):
agenda.extend(reversed(doc.children))
elif isinstance(doc, _BreakDoc):
seen_break = True
elif isinstance(doc, _NestDoc):
agenda.append(doc.child)
elif isinstance(doc, _GroupDoc):
agenda.append(doc.child)
elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc):
agenda.append(doc.child)
else:
raise ValueError("Invalid document ", doc)
return True
class _ColorState(NamedTuple):
  """A complete set of text attributes: foreground, background, intensity."""
  foreground: Color
  background: Color
  intensity: Intensity
class _State(NamedTuple):
  """One pending work item on the formatter's agenda."""
  # Number of spaces a new line starts with after a break.
  indent: int
  # Whether breaks in `doc` print as their text (FLAT) or as newlines (BREAK).
  mode: _BreakMode
  # The document still to be printed.
  doc: Doc
  # The colors text in `doc` should be printed with.
  color: _ColorState
  # The source object in effect for `doc`, or the _NO_SOURCE sentinel.
  source_map: Any
class _Line(NamedTuple):
  """One finished output line, before annotations are aligned."""
  # The line's characters, including any color escape strings.
  text: str
  # Width counted from document text only (color escape strings excluded).
  width: int
  # Annotations collected from the text on this line, in order.
  annotations: list[str]
def _update_color(use_color: bool, state: _ColorState, update: _ColorState
) -> tuple[_ColorState, str]:
if not use_color or colorama is None:
return update, ""
color_str = ""
if state.foreground != update.foreground:
color_str += getattr(colorama.Fore, str(update.foreground.name))
if state.background != update.background:
color_str += getattr(colorama.Back, str(update.background.name))
if state.intensity != update.intensity:
color_str += colorama.Style.NORMAL # pytype: disable=unsupported-operands
color_str += getattr(colorama.Style, str(update.intensity.name))
return update, color_str
def _align_annotations(lines: list[_Line], annotation_prefix: str) -> list[str]:
# TODO: Hafiz also implements a local alignment mode, where groups of lines
# with annotations are aligned together.
maxlen = max(l.width for l in lines)
out = []
for l in lines:
if len(l.annotations) == 0:
out.append(l.text)
else:
out.append(f"{l.text}{' ' * (maxlen - l.width)}"
f"{annotation_prefix}{l.annotations[0]}")
for a in l.annotations[1:]:
out.append(f"{' ' * maxlen}{annotation_prefix}{a}")
return out
def _format(
    doc: Doc, width: int, *, use_color: bool, annotation_prefix: str,
    source_map: list[list[tuple[int, int, Any]]] | None
) -> str:
  """Lays out `doc` and returns the resulting text.

  Non-recursive: pending sub-documents are kept on an explicit stack of
  `_State` records and processed until the stack is empty.

  Args:
    doc: the document to print.
    width: the line width groups are tested against.
    use_color: whether color escape strings are emitted.
    annotation_prefix: text placed in front of each annotation.
    source_map: if not None, one list of (start, end, source) tuples is
      appended to it per output line.

  Returns:
    The lines joined with newlines, followed by whatever escape string is
    needed to return to the default colors.

  Raises:
    ValueError: if an unknown document node is encountered.
  """
  plain = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL)
  dimmed = _ColorState(Color.RESET, Color.RESET, Intensity.DIM)

  finished = []                # Completed _Line records.
  active_colors = plain        # Colors most recently emitted.
  region_start = 0             # Where the current source region starts.
  region_source = _NO_SOURCE   # The currently active source region.
  regions = []                 # Source regions of the current line.
  column = 0                   # Width of the current line, text only.
  buf = ""                     # Characters of the current line.
  notes = []                   # Annotations of the current line.

  stack = [_State(0, _BreakMode.BREAK, doc, plain, region_source)]
  while stack:
    indent, mode, node, wanted, node_source = stack.pop()

    # A change of source closes the open region and starts a new one.
    if source_map is not None and node_source != region_source:
      here = len(buf)
      if region_start != here and region_source is not _NO_SOURCE:
        regions.append((region_start, here, region_source))
      region_source = node_source
      region_start = here

    if isinstance(node, _NilDoc):
      continue

    if isinstance(node, _TextDoc):
      active_colors, escape = _update_color(use_color, active_colors, wanted)
      buf += escape + node.text
      if node.annotation is not None:
        notes.append(node.annotation)
      column += len(node.text)
    elif isinstance(node, _ConcatDoc):
      # Push in reverse so that children are popped left to right.
      for child in reversed(node.children):
        stack.append(_State(indent, mode, child, wanted, region_source))
    elif isinstance(node, _BreakDoc):
      if mode == _BreakMode.BREAK:
        # Finish the current line and start a new one at `indent`.
        if notes:
          active_colors, escape = _update_color(use_color, active_colors,
                                                dimmed)
          buf += escape
        finished.append(_Line(buf, column, notes))
        if source_map is not None:
          here = len(buf)
          if region_start != here and region_source is not _NO_SOURCE:
            regions.append((region_start, here, region_source))
          source_map.append(regions)
          regions = []
          region_start = indent
        buf = " " * indent
        notes = []
        column = indent
      else:
        # Flat mode: the break prints as its text.
        active_colors, escape = _update_color(use_color, active_colors, wanted)
        buf += escape + node.text
        column += len(node.text)
    elif isinstance(node, _NestDoc):
      stack.append(
          _State(indent + node.n, mode, node.child, wanted, region_source))
    elif isinstance(node, _GroupDoc):
      # In Lindig's paper, _fits is passed the remainder of the document.
      # I'm pretty sure that's a bug and we care only if the current group fits!
      flat = _fits(node, width - column) and _sparse(node)
      group_mode = _BreakMode.FLAT if flat else _BreakMode.BREAK
      stack.append(_State(indent, group_mode, node.child, wanted,
                          region_source))
    elif isinstance(node, _ColorDoc):
      # Attributes left unset on the node fall back to the inherited colors.
      merged = _ColorState(node.foreground or wanted.foreground,
                           node.background or wanted.background,
                           node.intensity or wanted.intensity)
      stack.append(_State(indent, mode, node.child, merged, region_source))
    elif isinstance(node, _SourceMapDoc):
      stack.append(_State(indent, mode, node.child, wanted, node.source))
    else:
      raise ValueError("Invalid document ", node)

  # Flush the last line.
  if notes:
    active_colors, escape = _update_color(use_color, active_colors, dimmed)
    buf += escape
  if source_map is not None:
    here = len(buf)
    if region_start != here and region_source is not _NO_SOURCE:
      regions.append((region_start, here, region_source))
    source_map.append(regions)
  finished.append(_Line(buf, column, notes))

  body = "\n".join(_align_annotations(finished, annotation_prefix))
  _, reset = _update_color(use_color, active_colors, plain)
  return body + reset
# Public API.
def nil() -> Doc:
  """Returns the empty document.

  Every call returns the same module-level `_nil` instance.
  """
  return _nil
def text(s: str, annotation: str | None = None) -> Doc:
  """Returns a document for the literal text `s`.

  Args:
    s: the text to print.
    annotation: optional annotation string attached to the text.
  """
  return _TextDoc(s, annotation)
def concat(docs: Sequence[Doc]) -> Doc:
  """Returns the concatenation of `docs`.

  A sequence holding exactly one document is returned as that document,
  without wrapping it.
  """
  items = list(docs)
  return items[0] if len(items) == 1 else _ConcatDoc(items)
def brk(text: str = " ") -> Doc:
  """Returns a break.

  A break prints either as a newline or as `text`, depending on the enclosing
  group.

  Args:
    text: what the break prints as when it is not a newline; defaults to a
      single space.
  """
  return _BreakDoc(text)
def group(doc: Doc) -> Doc:
  """Layout alternative groups.

  Prints the group with its breaks as their text (typically spaces) if the
  entire group would fit on the line when printed that way. Otherwise, breaks
  inside the group are printed as newlines.
  """
  return _GroupDoc(doc)
def nest(n: int, doc: Doc) -> Doc:
  """Increases the indentation level of `doc` by `n`.

  Args:
    n: amount added to the current indentation.
    doc: the document to indent.
  """
  return _NestDoc(n, doc)
def color(doc: Doc, *, foreground: Color | None = None,
          background: Color | None = None,
          intensity: Intensity | None = None) -> Doc:
  """ANSI colors.

  Overrides the foreground/background/intensity of the text for the child doc.
  Attributes left as None are not overridden.

  Requires `use_color=True` to be in effect when formatting and the `colorama`
  package to be installed; otherwise does nothing.
  """
  return _ColorDoc(doc, foreground=foreground, background=background,
                   intensity=intensity)
def source_map(doc: Doc, source: Any) -> Doc:
  """Source mapping.

  A source map associates a region of the pretty-printer's text output with a
  source location that produced it. For the purposes of the pretty printer a
  ``source`` may be any object: we require only that we can compare sources for
  equality. A text region to source object mapping can be populated as a side
  output of the ``format`` method.

  Args:
    doc: the document whose output should be attributed to `source`.
    source: the object to associate with that output.
  """
  return _SourceMapDoc(doc, source)
else:
# Re-export the enum and document types provided by `_pretty_printer` under
# the names this module exposes.
Color = _pretty_printer.Color
Intensity = _pretty_printer.Intensity
Doc = _pretty_printer.Doc
def _format(
self, width: int = 80, *, use_color: bool | None = None,
annotation_prefix: str = " # ",
source_map: list[list[tuple[int, int, Any]]] | None = None
) -> str:
"""
Formats a pretty-printer document as a string.
Args:
source_map: for each line in the output, contains a list of
(start column, end column, source) tuples. Each tuple associates a
region of output text with a source.
"""
if use_color is None:
use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value
return self._format(
width, use_color=use_color, annotation_prefix=annotation_prefix,
source_map=source_map)
# Attach the Python-level `format` wrapper (which resolves the `use_color`
# default) to the imported Doc class, and make str(doc) use it.
Doc.format = _format
Doc.__str__ = lambda self: self.format()
# Re-export the document constructors provided by `_pretty_printer`.
nil = _pretty_printer.nil
text = _pretty_printer.text
concat = _pretty_printer.concat
brk = _pretty_printer.brk
group = _pretty_printer.group
nest = _pretty_printer.nest
color = _pretty_printer.color
source_map = _pretty_printer.source_map
# Preset variants of `color`: magenta at normal intensity for type
# annotations, bright blue for keywords.
type_annotation = partial(color, intensity=Intensity.NORMAL,
                          foreground=Color.MAGENTA)
keyword = partial(color, intensity=Intensity.BRIGHT, foreground=Color.BLUE)
def join(sep: Doc, docs: Sequence[Doc]) -> Doc:
  """Concatenates `docs`, placing `sep` between neighbouring documents.

  Returns the empty document for no input, and the sole document itself
  when `docs` holds exactly one.
  """
  items = list(docs)
  if not items:
    return nil()
  if len(items) == 1:
    return items[0]
  interleaved = [items[0]]
  for item in items[1:]:
    interleaved += [sep, item]
  return concat(interleaved)