# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Source mapping generator for Jaxprs.""" import re from typing import Any import jax from jax._src import config from jax._src import core from jax._src import source_info_util from jax._src import sourcemap from jax.experimental.source_mapper import common source_info_util.register_exclusion(__file__) def compile_jaxpr(work_dir, f, f_args, f_kwargs, **_): del work_dir return jax.make_jaxpr(f)(*f_args, **f_kwargs) def canonicalize_filename(file_name: str): pattern = config.hlo_source_file_canonicalization_regex.value if pattern: file_name = re.sub(pattern, '', file_name) return file_name def make_jaxpr_dump(jaxpr: core.Jaxpr, **_) -> common.SourceMapDump: pprint_mappings: list[list[tuple[int, int, Any]]] = [] pprint_str = jaxpr.pretty_print(source_map=pprint_mappings) used_source_files = [] mappings = sourcemap.MappingsGenerator() for pprint_map_line in pprint_mappings: mappings.new_group() for pprint_segment in pprint_map_line: start_col, end_col, frame = pprint_segment del end_col file_name = canonicalize_filename(frame.file_name) if file_name not in used_source_files: used_source_files.append(file_name) file_idx = used_source_files.index(file_name) src_line = frame.start_line - 1 # Zero-indexed src_col = frame.start_column # A segment is a tuple of the form: # (generated_col, src_file_idx, src_line, src_col) mappings.new_segment(start_col, file_idx, src_line, src_col) mappings.new_group() source_map = sourcemap.SourceMap( version=3, sources=used_source_files, sources_content=[], mappings=mappings.mappings(), names=[], ) return common.SourceMapDump( source_map=source_map, generated_code=pprint_str, pass_name='jaxpr', ) common.register_pass( common.Pass( name='jaxpr', compile_fn=compile_jaxpr, # type: ignore[arg-type] generate_dump=make_jaxpr_dump, # type: ignore[arg-type] ) )