# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for generating source maps."""

import contextlib
import dataclasses
import re
from typing import Any, Protocol, Sequence

from absl import flags
import jax
from jax._src import sourcemap


@dataclasses.dataclass(frozen=True)
class SourceMapDump:
  """Immutable pairing of a source map with the code it was generated for.

  Attributes:
    source_map: The source map for `generated_code`.
    generated_code: The generated code paired with `source_map`.
    pass_name: Name of the pass this dump belongs to.
  """

  # Field order is part of the positional constructor; keep it stable.
  source_map: sourcemap.SourceMap
  generated_code: str
  pass_name: str


class CompileFn(Protocol):
  """Structural type for a pass's compile step."""

  def __call__(self, work_dir, fn, f_args, f_kwargs, **kwargs) -> Any:
    """Compiles `fn` with the given positional and keyword arguments.

    What `work_dir` is used for and what gets returned are left to each
    implementation. NOTE(review): presumably the return value is what is
    later handed to the matching `GenerateDumpFn`; confirm against the
    registered passes.
    """
    ...


class GenerateDumpFn(Protocol):
  """Structural type for a pass's dump-generation step."""

  def __call__(self, compile_result: Any, **kwargs) -> SourceMapDump:
    """Builds a `SourceMapDump` from a compile result."""
    ...
@dataclasses.dataclass(frozen=True)
class Pass:
  """A named pass pairing a compile step with a dump-generation step.

  Attributes:
    name: Unique key under which the pass is registered; also the string
      that `filter_passes` matches against.
    compile_fn: Callable that performs the compile step for this pass.
    generate_dump: Callable that builds a `SourceMapDump` from a compile
      result.
  """

  name: str
  compile_fn: CompileFn
  generate_dump: GenerateDumpFn


# Global registry of passes keyed by `Pass.name`, in registration order.
_pass_registry: dict[str, Pass] = {}


def register_pass(pass_: Pass):
  """Adds `pass_` to the global pass registry.

  Args:
    pass_: The pass to register.

  Raises:
    ValueError: If a pass with the same name is already registered.
  """
  if pass_.name in _pass_registry:
    raise ValueError(f"Pass {pass_.name} already registered")
  _pass_registry[pass_.name] = pass_


def all_passes() -> Sequence[Pass]:
  """Returns a new list of all registered passes, in registration order."""
  return list(_pass_registry.values())


def filter_passes(regex: str) -> Sequence[Pass]:
  """Gets all registered passes whose name matches the given regex.

  Matching uses `re.match`, so the pattern is anchored at the start of the
  name but not at the end: "foo" also selects "foo_bar". Append "$" to the
  pattern to require a full-name match.

  Args:
    regex: Regular expression to match against each `Pass.name`.

  Returns:
    A new list of matching passes, in registration order.
  """
  return [
      pass_ for pass_ in _pass_registry.values() if re.match(regex, pass_.name)
  ]


@contextlib.contextmanager
def flag_env(**kwargs):
  """A context manager that temporarily overrides values on `flags.FLAGS`.

  Each keyword names a flag and gives the value to use inside the `with`
  block. Every named flag is restored to its previous value on exit. This
  also happens when the body raises, or when assigning one of the new
  values fails part-way through.

  Args:
    **kwargs: Flag name to temporary value.

  Yields:
    None.
  """
  # Snapshot every flag before modifying any. If a name cannot be read,
  # this raises before any flag has been changed.
  old_flags = {kwarg: getattr(flags.FLAGS, kwarg) for kwarg in kwargs}
  try:
    # Assign inside the `try`. If setting one flag raises (e.g. a flag
    # validator rejects the value), the flags already overridden are still
    # restored by the `finally` below. Writing an old value back to a flag
    # that was never changed is harmless.
    for kwarg, new_value in kwargs.items():
      setattr(flags.FLAGS, kwarg, new_value)
    yield
  finally:
    for kwarg, old_value in old_flags.items():
      setattr(flags.FLAGS, kwarg, old_value)


def compile_with_env(f, f_args, f_kwargs, env_flags, compiler_flags):
  """Jits, lowers and compiles `f` while `env_flags` are in effect.

  Args:
    f: The Python function to compile.
    f_args: Positional arguments used to lower `f`.
    f_kwargs: Keyword arguments used to lower `f`.
    env_flags: Flag overrides applied with `flag_env` for the duration of
      lowering and compilation.
    compiler_flags: Passed positionally to `Lowered.compile`.

  Returns:
    The object returned by `Lowered.compile`.
  """
  with flag_env(**env_flags):
    # NOTE(review): `f` is wrapped in a fresh lambda on every call. This
    # presumably gives jit a new function identity so a cached executable is
    # not reused and compilation actually runs. Confirm before simplifying.
    return jax.jit(lambda *args, **kwargs: f(*args, **kwargs)).lower(  # pylint: disable=unnecessary-lambda
        *f_args, **f_kwargs
    ).compile(compiler_flags)