81 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			81 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2025 The JAX Authors.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     https://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """Source mapping generator for Jaxprs."""
 | |
| import re
 | |
| from typing import Any
 | |
| 
 | |
| import jax
 | |
| from jax._src import config
 | |
| from jax._src import core
 | |
| from jax._src import source_info_util
 | |
| from jax._src import sourcemap
 | |
| from jax.experimental.source_mapper import common
 | |
| 
 | |
# Register this module's own file with source_info_util's exclusion list.
# NOTE(review): presumably this keeps frames from this source-mapper helper
# out of the user-facing source locations JAX records; confirm against
# source_info_util.register_exclusion.
source_info_util.register_exclusion(__file__)
 | |
| 
 | |
| 
 | |
def compile_jaxpr(work_dir, f, f_args, f_kwargs, **_):
  """Traces `f` on the given arguments and returns the resulting jaxpr.

  Args:
    work_dir: Unused by this pass; accepted so the signature matches the
      other source-mapper compile functions.
    f: The function to trace.
    f_args: Positional arguments forwarded to `f` during tracing.
    f_kwargs: Keyword arguments forwarded to `f` during tracing.
    **_: Any additional keyword arguments are ignored.

  Returns:
    Whatever `jax.make_jaxpr(f)(*f_args, **f_kwargs)` returns.
  """
  del work_dir  # Not needed: the jaxpr is produced in memory.
  trace_to_jaxpr = jax.make_jaxpr(f)
  return trace_to_jaxpr(*f_args, **f_kwargs)
 | |
| 
 | |
| 
 | |
def canonicalize_filename(file_name: str):
  """Removes every match of the configured canonicalization regex from a path.

  The regex comes from `config.hlo_source_file_canonicalization_regex`. When
  that option is unset or empty, `file_name` is returned unchanged.

  Args:
    file_name: The source file name to canonicalize.

  Returns:
    `file_name` with all regex matches deleted (via `re.sub`), or `file_name`
    itself if no regex is configured.
  """
  regex = config.hlo_source_file_canonicalization_regex.value
  return re.sub(regex, '', file_name) if regex else file_name
 | |
| 
 | |
| 
 | |
def make_jaxpr_dump(
    jaxpr: core.Jaxpr | core.ClosedJaxpr, **_
) -> common.SourceMapDump:
  """Pretty-prints a jaxpr and builds a source map from its text to user code.

  Args:
    jaxpr: The jaxpr to dump. `compile_jaxpr` produces it with
      `jax.make_jaxpr`, which returns a `core.ClosedJaxpr`; a plain
      `core.Jaxpr` is also accepted. Either way, only its
      `pretty_print(source_map=...)` method is used.
    **_: Any additional keyword arguments are ignored.

  Returns:
    A `common.SourceMapDump` with pass name 'jaxpr'. Its `generated_code` is
    the pretty-printed jaxpr. Its `source_map` (version 3) maps each printed
    segment to the canonicalized file name, zero-based line, and column of the
    frame the pretty-printer attached to that segment.
  """
  # pretty_print fills this with one inner list per generated line. Each
  # entry is a (start_col, end_col, frame) triple.
  pprint_mappings: list[list[tuple[int, int, Any]]] = []
  pprint_str = jaxpr.pretty_print(source_map=pprint_mappings)

  # Source file names in first-use order; a name's position in this list is
  # the file index written into every segment that refers to it. The dict
  # gives O(1) lookups instead of two linear list scans per segment.
  used_source_files: list[str] = []
  file_indices: dict[str, int] = {}

  mappings = sourcemap.MappingsGenerator()
  for pprint_map_line in pprint_mappings:
    mappings.new_group()  # One mapping group per generated line.
    for start_col, _end_col, frame in pprint_map_line:
      file_name = canonicalize_filename(frame.file_name)
      file_idx = file_indices.get(file_name)
      if file_idx is None:
        file_idx = len(used_source_files)
        file_indices[file_name] = file_idx
        used_source_files.append(file_name)
      src_line = frame.start_line - 1  # Zero-indexed
      src_col = frame.start_column
      # A segment is a tuple of the form:
      # (generated_col, src_file_idx, src_line, src_col)
      mappings.new_segment(start_col, file_idx, src_line, src_col)
  # Appends one final, empty group after the last printed line.
  # NOTE(review): presumably this accounts for pprint_str's final line;
  # kept unchanged.
  mappings.new_group()

  source_map = sourcemap.SourceMap(
      version=3,
      sources=used_source_files,
      sources_content=[],
      mappings=mappings.mappings(),
      names=[],
  )
  return common.SourceMapDump(
      source_map=source_map,
      generated_code=pprint_str,
      pass_name='jaxpr',
  )
 | |
| 
 | |
| 
 | |
# Register the 'jaxpr' source-mapper pass. compile_fn traces the user function
# to a jaxpr; generate_dump turns that result into a SourceMapDump. The
# descriptor is bound to a temporary name that is deleted afterwards, so the
# module namespace is unchanged.
_jaxpr_pass = common.Pass(
    generate_dump=make_jaxpr_dump,  # type: ignore[arg-type]
    compile_fn=compile_jaxpr,  # type: ignore[arg-type]
    name='jaxpr',
)
common.register_pass(_jaxpr_pass)
del _jaxpr_pass
 |