# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common utilities for generating source maps."""
|
|
import contextlib
|
|
import dataclasses
|
|
import re
|
|
from typing import Any, Protocol, Sequence
|
|
|
|
from absl import flags
|
|
import jax
|
|
from jax._src import sourcemap
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class SourceMapDump:
  """Immutable record pairing a pass's generated code with its source map.

  Attributes:
    source_map: The source map that accompanies ``generated_code``.
    generated_code: The code text emitted by a compilation pass.
    pass_name: Name of the pass this dump belongs to.
  """

  source_map: sourcemap.SourceMap  # paired with generated_code below
  generated_code: str  # text produced by the pass
  pass_name: str  # identifies which pass produced the dump
|
|
|
|
|
|
class CompileFn(Protocol):
  """Structural type for the compile step of a pass.

  Any callable accepting ``(work_dir, fn, f_args, f_kwargs, **kwargs)``
  conforms. The return value is opaque at this level (typed ``Any``); its
  meaning is defined by the concrete pass.
  """

  def __call__(self, work_dir, fn, f_args, f_kwargs, **kwargs) -> Any:
    """Runs the pass's compilation for ``fn`` with ``f_args``/``f_kwargs``.

    NOTE(review): ``work_dir`` is presumably a directory the compilation may
    write artifacts into; confirm against concrete implementations.
    """
|
|
|
|
|
|
class GenerateDumpFn(Protocol):
  """Structural type for the dump-extraction step of a pass.

  Any callable accepting ``(compile_result, **kwargs)`` and returning a
  ``SourceMapDump`` conforms.
  """

  def __call__(self, compile_result: Any, **kwargs) -> SourceMapDump:
    """Builds a ``SourceMapDump`` from ``compile_result``.

    NOTE(review): ``compile_result`` is presumably the value returned by the
    same pass's ``CompileFn``; confirm against callers.
    """
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Pass:
  """Immutable description of a pass that can be registered by name.

  Attributes:
    name: Key under which the pass is stored in the registry; registering a
      second pass with the same name raises ``ValueError``.
    compile_fn: The pass's compile step.
    generate_dump: Turns a compile result into a ``SourceMapDump``.
  """

  name: str  # registry key
  compile_fn: CompileFn  # compile step
  generate_dump: GenerateDumpFn  # dump-extraction step
|
|
|
|
|
|
# Module-level registry of passes keyed by their name. Plain dicts keep
# insertion order, so iterating over it yields passes in registration order.
_pass_registry: dict[str, "Pass"] = {}
|
|
|
|
|
|
def register_pass(pass_: Pass):
  """Adds ``pass_`` to the module-level registry under ``pass_.name``.

  Args:
    pass_: The pass to register.

  Raises:
    ValueError: If a pass with the same name is already registered.
  """
  key = pass_.name
  if key not in _pass_registry:
    _pass_registry[key] = pass_
    return
  raise ValueError(f"Pass {key} already registered")
|
|
|
|
|
|
def all_passes() -> Sequence[Pass]:
  """Returns a new list of every registered pass, in registration order."""
  return [*_pass_registry.values()]
|
|
|
|
|
|
def filter_passes(regex: str) -> Sequence[Pass]:
  """Returns the registered passes whose name matches ``regex``.

  Matching uses ``re.match``, so the pattern is anchored at the start of the
  name only: ``"foo"`` selects both ``"foo"`` and ``"foobar"``. Results keep
  registration order.

  Args:
    regex: Regular expression tested against each pass's ``name``.

  Returns:
    A new list of the matching passes.
  """
  return list(
      filter(
          lambda candidate: re.match(regex, candidate.name),
          _pass_registry.values(),
      )
  )
|
|
|
|
|
|
@contextlib.contextmanager
def flag_env(**kwargs):
  """A context manager that temporarily overrides absl flags.

  Each keyword argument names a flag on ``flags.FLAGS`` and gives the value it
  holds inside the ``with`` block. All previous values are read before any flag
  is modified. They are written back when the block exits, whether normally or
  via an exception.

  If assigning one of the overrides raises, the overrides already applied are
  rolled back before the exception propagates. A failed ``flag_env(...)``
  therefore never leaves flags partially modified.

  Args:
    **kwargs: Flag name -> temporary value.

  Yields:
    None.

  Raises:
    Any exception raised while reading or assigning a flag on ``flags.FLAGS``
    (e.g. an unknown flag name or a value absl rejects).
  """
  old_flags = {name: getattr(flags.FLAGS, name) for name in kwargs}
  # Track which overrides actually took effect so that only those are undone;
  # assignment happens inside the `try` so a mid-way failure still restores.
  applied = []
  try:
    for name, new_value in kwargs.items():
      setattr(flags.FLAGS, name, new_value)
      applied.append(name)
    yield
  finally:
    # Unwind in reverse (LIFO). Each intermediate state then matches a state
    # that already existed, and was accepted, while the overrides were applied.
    for name in reversed(applied):
      setattr(flags.FLAGS, name, old_flags[name])
|
|
|
|
|
|
def compile_with_env(f, f_args, f_kwargs, env_flags, compiler_flags):
  """Jit-lowers and compiles ``f`` while temporary flag overrides are active.

  Args:
    f: The Python callable to compile.
    f_args: Positional arguments forwarded to ``lower``.
    f_kwargs: Keyword arguments forwarded to ``lower``.
    env_flags: Mapping of flag name -> value, applied with ``flag_env`` for
      the duration of lowering and compilation.
    compiler_flags: Forwarded positionally to ``Lowered.compile`` as its
      compiler options.

  Returns:
    The compiled executable produced by ``Lowered.compile``. Callers that only
    rely on the side effects of compilation may ignore it.
  """
  with flag_env(**env_flags):
    # Wrapping `f` in a fresh lambda gives jit a new function object on every
    # call. NOTE(review): presumably this keeps previously cached traces of `f`
    # from being reused, so lowering and compilation (and anything `env_flags`
    # controls) really run each time; confirm intent.
    jitted = jax.jit(lambda *args, **kwargs: f(*args, **kwargs))  # pylint: disable=unnecessary-lambda
    lowered = jitted.lower(*f_args, **f_kwargs)
    return lowered.compile(compiler_flags)
|