81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
# Copyright 2025 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Source mapping generator for Jaxprs."""
|
|
import re
|
|
from typing import Any
|
|
|
|
import jax
|
|
from jax._src import config
|
|
from jax._src import core
|
|
from jax._src import source_info_util
|
|
from jax._src import sourcemap
|
|
from jax.experimental.source_mapper import common
|
|
|
|
# Register this module's file with source_info_util's exclusion list.
# NOTE(review): presumably this keeps frames originating in this module from
# being reported as user source locations when jaxprs are traced — confirm.
source_info_util.register_exclusion(__file__)
|
|
|
|
|
|
def compile_jaxpr(work_dir, f, f_args, f_kwargs, **_):
  """Traces ``f`` with the given arguments and returns the resulting jaxpr.

  Args:
    work_dir: Accepted to match the pass ``compile_fn`` signature; unused here.
    f: The function to trace.
    f_args: Positional arguments forwarded to the traced call.
    f_kwargs: Keyword arguments forwarded to the traced call.
    **_: Any further keyword arguments from the pass driver; ignored.

  Returns:
    The value produced by ``jax.make_jaxpr(f)(*f_args, **f_kwargs)``.
  """
  del work_dir  # This pass has no use for a working directory.
  tracer = jax.make_jaxpr(f)
  return tracer(*f_args, **f_kwargs)
|
|
|
|
|
|
def canonicalize_filename(file_name: str) -> str:
  """Removes every match of the configured canonicalization regex.

  The pattern is read from ``config.hlo_source_file_canonicalization_regex``
  on each call. When that value is unset or empty, ``file_name`` is returned
  untouched; otherwise all matches are deleted with ``re.sub``.
  """
  regex = config.hlo_source_file_canonicalization_regex.value
  if not regex:
    return file_name
  return re.sub(regex, '', file_name)
|
|
|
|
|
|
def make_jaxpr_dump(
    jaxpr: core.Jaxpr | core.ClosedJaxpr, **_
) -> common.SourceMapDump:
  """Pretty-prints ``jaxpr`` and builds a source map for the printed text.

  Args:
    jaxpr: The jaxpr to print. Within the ``jaxpr`` pass this is whatever
      ``compile_jaxpr`` returned, i.e. the result of
      ``jax.make_jaxpr(f)(...)``.
    **_: Extra keyword arguments from the pass driver; ignored.

  Returns:
    A ``common.SourceMapDump`` with the pretty-printed text as
    ``generated_code`` and a version-3 source map whose segments point from
    columns of that text back to canonicalized source file/line/column.
  """
  # Filled in by pretty_print. Each inner list is turned into one mapping
  # group (one generated line) below; each entry is
  # (start_col, end_col, frame).
  pprint_mappings: list[list[tuple[int, int, Any]]] = []
  pprint_str = jaxpr.pretty_print(source_map=pprint_mappings)

  # Canonical file name -> index into the source map's ``sources`` list.
  # Dict insertion order is first-seen order, which becomes ``sources``;
  # lookups are O(1) per segment instead of a linear scan.
  source_indices: dict[str, int] = {}
  mappings = sourcemap.MappingsGenerator()
  for pprint_map_line in pprint_mappings:
    mappings.new_group()
    # The end column is not needed: a source map segment only records
    # where it starts.
    for start_col, _, frame in pprint_map_line:
      file_name = canonicalize_filename(frame.file_name)
      # len() is evaluated before insertion, so a new file gets the next index.
      file_idx = source_indices.setdefault(file_name, len(source_indices))
      src_line = frame.start_line - 1  # Source map lines are zero-indexed.
      src_col = frame.start_column
      # A segment is a tuple of the form:
      # (generated_col, src_file_idx, src_line, src_col)
      mappings.new_segment(start_col, file_idx, src_line, src_col)
  # Starts one more (empty) group after the last printed line.
  # NOTE(review): kept from the original; confirm it is meant to cover a
  # trailing newline in pprint_str.
  mappings.new_group()

  source_map = sourcemap.SourceMap(
      version=3,
      sources=list(source_indices),
      sources_content=[],
      mappings=mappings.mappings(),
      names=[],
  )
  return common.SourceMapDump(
      source_map=source_map,
      generated_code=pprint_str,
      pass_name='jaxpr',
  )
|
|
|
|
|
|
# Register the 'jaxpr' pass. It pairs compile_jaxpr (traces the function with
# jax.make_jaxpr) with make_jaxpr_dump (pretty-prints the result and builds
# its source map).
common.register_pass(
    common.Pass(
        name='jaxpr',
        generate_dump=make_jaxpr_dump,  # type: ignore[arg-type]
        compile_fn=compile_jaxpr,  # type: ignore[arg-type]
    )
)
|