# Source code for mljet.contrib.analyzer

"""Model analyzer and methods extractor."""

import inspect
import json
import logging
import warnings
from typing import (
    Callable,
    Dict,
    List,
)

from mljet.contrib.supported import ModelType
from mljet.cookie.templates.ml.dispatcher import get_dual_methods
from mljet.utils.types import Estimator

# Whitelist of model method names that `extract_methods_names` is allowed
# to report; any other method found on a model is ignored.
_SUPPORTED_METHODS = (
    "predict",
    "predict_proba",
)

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


def extract_methods_names(model: Estimator) -> List[str]:
    """Get the names of the supported methods bound to a model.

    Args:
        model: Model object to inspect.

    Returns:
        Names of the members of ``model`` that are bound methods and are
        listed in ``_SUPPORTED_METHODS``, in the name-sorted order produced
        by ``inspect.getmembers``.
    """
    # The whitelist is the whole filter: every name in ``_SUPPORTED_METHODS``
    # is public and starts with "predict", so separate prefix checks would
    # be redundant.
    return [
        name
        for name, _ in inspect.getmembers(model, inspect.ismethod)
        if name in _SUPPORTED_METHODS
    ]


def get_associated_methods_wrappers(
    model: Estimator,
) -> Dict[str, Callable]:
    """Get methods names and associated wrappers.

    Args:
        model: Model object whose supported methods are looked up.

    Returns:
        Mapping from each extracted method name to the wrapper that
        ``get_dual_methods`` returned at the same position.
    """
    # Warnings raised while the model's members are inspected are silenced;
    # the previous warning filters are restored when the block exits.
    # NOTE(review): suppression is scoped to method extraction only --
    # confirm that ModelType.from_model / get_dual_methods are not meant to
    # run with warnings silenced as well.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        extracted = extract_methods_names(model)
    mt = ModelType.from_model(model)
    # Names and wrappers are paired positionally, and zip stops at the
    # shorter input -- presumably get_dual_methods returns one wrapper per
    # name, in the same order; verify against the dispatcher.
    associated = dict(zip(extracted, get_dual_methods(mt, extracted)))
    log.info(f"Detected model methods: {json.dumps(extracted, indent=4)}")
    return associated