# Source code for mljet.cookie.templates.ml.dispatcher

"""Dispatcher for supported model types."""
import logging
import sys
from functools import lru_cache
from importlib.util import (
    module_from_spec,
    spec_from_file_location,
)
from pathlib import Path
from types import ModuleType
from typing import (
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
)

from mljet.contrib.supported import ModelType

BASES_PATH = Path(__file__).parent.resolve()

log = logging.getLogger(__name__)


def _load_backend_module(path: Path) -> Optional[ModuleType]:
    """Execute the file at *path* as a top-level module named after its stem.

    Returns ``None`` (after logging a warning) when no spec or loader can be
    built for the file. Any exception raised by the module's own code is
    re-raised. Before that, the ``sys.modules`` entry for that name is
    restored to what it was before this call.
    """
    name = path.stem
    spec = spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        log.warning("Cannot build an import spec for %s, skipping", path)
        return None
    mod = module_from_spec(spec)
    # Register before executing, mirroring the import system, so the module is
    # present in ``sys.modules`` while its own code runs.
    # NOTE(review): the key is the bare file stem, so a backend file named like
    # an installed top-level package would shadow it -- confirm none collide.
    previous = sys.modules.get(name)
    sys.modules[name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Don't leave a half-initialised module registered under ``name``.
        if previous is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return mod


@lru_cache(None)
def get_all_supported_ml_kinds() -> Dict[str, ModuleType]:
    """Returns all default backends.

    Every ``*.py`` file found recursively under ``BASES_PATH`` is executed as
    a module. Files whose name starts with ``__`` and files sharing this
    module's file name are excluded. Each loaded module is registered in
    ``sys.modules`` under its file stem. Every entry of its ``USED_FOR``
    attribute is then mapped to that module. Modules without a (truthy)
    ``USED_FOR`` are logged and contribute nothing.

    Files are visited in sorted path order. If two backends list the same
    ``USED_FOR`` entry, the later path therefore wins deterministically,
    instead of depending on filesystem iteration order.

    The result is cached (``lru_cache``), so discovery runs once per process
    unless ``get_all_supported_ml_kinds.cache_clear()`` is called.

    Returns:
        Mapping from each ``USED_FOR`` entry to the backend module declaring it.

    Raises:
        Whatever a backend module raises while executing; its ``sys.modules``
        entry is rolled back first.
    """
    this_file = Path(__file__).name
    backend_files = sorted(
        path
        for path in BASES_PATH.rglob("*.py")
        if path.is_file()
        and not path.name.startswith("__")
        and path.name != this_file
    )
    supported2mod: Dict[str, ModuleType] = {}
    for file in backend_files:
        mod = _load_backend_module(file)
        if mod is None:
            continue
        used_for = getattr(mod, "USED_FOR", [])
        if not used_for:
            log.critical(
                "Module %s exists but has no `USED_FOR`, skipping", file.stem
            )
        for mt in used_for:
            supported2mod[mt] = mod
    return supported2mod
# Backend modules keyed by the entries of their `USED_FOR` attribute, discovered
# eagerly when this module is imported; `get_dual_methods` looks model types up here.
SUPPORTED_ML_KINDS: Dict[str, ModuleType] = get_all_supported_ml_kinds()
def get_dual_methods(mt: ModelType, methods: Sequence[str]) -> List[Callable]:
    """Get dual methods, needed to replace in backend templates.

    Args:
        mt: Model type whose backend module is looked up in ``SUPPORTED_ML_KINDS``.
        methods: Attribute names to fetch from that backend module.

    Returns:
        The fetched attributes, in the same order as ``methods``.

    Raises:
        ValueError: If ``mt`` has no registered backend, or if a requested
            name is missing from the backend (or bound to ``None``). The first
            offending name in ``methods`` is the one reported.
    """
    backend = SUPPORTED_ML_KINDS.get(mt)
    if not backend:
        raise ValueError(f"No such model type: {mt}")

    def _resolve(name: str) -> Callable:
        # Missing attributes and attributes set to None are both rejected.
        func: Optional[Callable] = getattr(backend, name, None)
        if func is None:
            raise ValueError(f"Method `{name}` not supported for {mt}")
        return func

    return [_resolve(name) for name in methods]