2025-08-11 12:24:21 +08:00

81 lines
2.5 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Source mapping generator for Jaxprs."""
import re
from typing import Any
import jax
from jax._src import config
from jax._src import core
from jax._src import source_info_util
from jax._src import sourcemap
from jax.experimental.source_mapper import common
# Register this module's path with source_info_util's exclusion mechanism.
# NOTE(review): presumably this keeps frames from this file from being reported
# as user source locations in jaxpr source info; confirm in source_info_util.
source_info_util.register_exclusion(__file__)
def compile_jaxpr(work_dir, f, f_args, f_kwargs, **_):
  """Trace ``f`` on the given arguments and return the resulting jaxpr.

  Args:
    work_dir: Accepted for compatibility with the pass interface; unused.
    f: The function to trace.
    f_args: Positional arguments passed to ``f`` while tracing.
    f_kwargs: Keyword arguments passed to ``f`` while tracing.
    **_: Any additional pass options; ignored.

  Returns:
    Whatever ``jax.make_jaxpr(f)(*f_args, **f_kwargs)`` produces.
  """
  del work_dir  # Not needed: tracing to a jaxpr writes no files.
  trace_to_jaxpr = jax.make_jaxpr(f)
  return trace_to_jaxpr(*f_args, **f_kwargs)
def canonicalize_filename(file_name: str):
  """Remove the configured canonicalization pattern from a source file path.

  The regex is read from ``config.hlo_source_file_canonicalization_regex``.
  Every match of it is deleted from ``file_name`` via ``re.sub``. When the
  option is unset or empty, ``file_name`` is returned unchanged.
  """
  regex = config.hlo_source_file_canonicalization_regex.value
  return re.sub(regex, '', file_name) if regex else file_name
def make_jaxpr_dump(jaxpr: core.Jaxpr, **_) -> common.SourceMapDump:
  """Pretty-print ``jaxpr`` and build a source map back to its Python source.

  ``jaxpr.pretty_print`` fills ``pprint_mappings`` with one entry per
  generated line. Each entry is a list of ``(start_col, end_col, frame)``
  spans. Every entry becomes one mapping group, and every span becomes a
  segment pointing at the frame's (canonicalized) file, line, and column.

  Args:
    jaxpr: The jaxpr to dump. Only its ``pretty_print`` method is used.
    **_: Any additional pass options; ignored.

  Returns:
    A ``common.SourceMapDump`` holding the version-3 source map, the
    pretty-printed text, and ``pass_name='jaxpr'``.
  """
  pprint_mappings: list[list[tuple[int, int, Any]]] = []
  pprint_str = jaxpr.pretty_print(source_map=pprint_mappings)

  # Canonicalized file name -> index into the source map's ``sources`` list.
  # Dicts preserve insertion order, so indices follow first-seen order. A
  # lookup is O(1) instead of two linear list scans (``in`` + ``.index``).
  source_indices: dict[str, int] = {}
  mappings = sourcemap.MappingsGenerator()
  for pprint_map_line in pprint_mappings:
    mappings.new_group()
    for start_col, _end_col, frame in pprint_map_line:
      file_name = canonicalize_filename(frame.file_name)
      file_idx = source_indices.setdefault(file_name, len(source_indices))
      src_line = frame.start_line - 1  # Zero-indexed
      src_col = frame.start_column
      # A segment is a tuple of the form:
      # (generated_col, src_file_idx, src_line, src_col)
      mappings.new_segment(start_col, file_idx, src_line, src_col)
  # A trailing (empty) group is appended after the last generated line.
  # NOTE(review): presumably this accounts for the final line of pprint_str;
  # confirm before removing.
  mappings.new_group()

  source_map = sourcemap.SourceMap(
      version=3,
      sources=list(source_indices),
      sources_content=[],
      mappings=mappings.mappings(),
      names=[],
  )
  return common.SourceMapDump(
      source_map=source_map,
      generated_code=pprint_str,
      pass_name='jaxpr',
  )
# Register the 'jaxpr' stage with the source mapper. compile_jaxpr produces
# the traced jaxpr, and make_jaxpr_dump turns it into text plus a source map.
common.register_pass(
    common.Pass(
        name='jaxpr',
        generate_dump=make_jaxpr_dump,  # type: ignore[arg-type]
        compile_fn=compile_jaxpr,  # type: ignore[arg-type]
    )
)