from abc import ABCMeta, abstractmethod
from collections.abc import Iterable, Iterator, Mapping, Sequence
from typing import Any, Callable, ClassVar, Generic, TypeVar
from typing_extensions import Self

from numpy import ndarray
from numpy.random import RandomState
from scipy.sparse import spmatrix

from .._typing import ArrayLike, Float, Int, MatrixLike
from ..base import BaseEstimator, MetaEstimatorMixin
from . import BaseCrossValidator

# Estimator type used to parameterize the generic search classes (it types
# `estimator` and `best_estimator_`). Bare, unparameterized uses resolve to
# the `default`, i.e. BaseEstimator.
_BaseEstimatorT = TypeVar(
    "_BaseEstimatorT",
    covariant=True,
    bound=BaseEstimator,
    default=BaseEstimator,
)

# Names exported by `from ... import *` on this module.
__all__ = [
    "GridSearchCV",
    "ParameterGrid",
    "ParameterSampler",
    "RandomizedSearchCV",
]

class ParameterGrid:
    """Stub for a parameter grid.

    Built from either one mapping, or a sequence of mappings, whose keys are
    ``str`` and whose values are sequences. Iteration and integer indexing
    both produce ``dict[str, Any]`` items; ``len()`` returns an ``int``.
    """

    def __init__(
        self,
        param_grid: Mapping[str, Sequence] | Sequence[Mapping[str, Sequence]],
    ) -> None: ...
    def __iter__(self) -> Iterator[dict[str, Any]]: ...
    def __len__(self) -> int: ...
    def __getitem__(self, ind: Int) -> dict[str, Any]: ...

class ParameterSampler:
    """Stub for a parameter sampler configured by ``param_distributions``.

    ``param_distributions`` accepts a single mapping or a sequence of
    mappings, the same shapes ``RandomizedSearchCV`` accepts for its
    ``param_distributions`` argument. A plain ``dict`` (the previously
    declared type) is still accepted. Iteration produces ``dict[str, Any]``
    items, matching ``ParameterGrid``.
    """

    def __init__(
        self,
        param_distributions: Mapping | Sequence[Mapping],
        n_iter: Int,
        *,
        random_state: RandomState | None | Int = None,
    ) -> None: ...
    # Previously unannotated (implicitly Any); typed like ParameterGrid.__iter__.
    def __iter__(self) -> Iterator[dict[str, Any]]: ...
    def __len__(self) -> int: ...

class BaseSearchCV(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
    """Abstract base stub shared by ``GridSearchCV`` and ``RandomizedSearchCV``.

    ``__init__`` is abstract. Each subclass declares its own constructor,
    adding its search-space argument (``param_grid`` or
    ``param_distributions``).
    """

    @abstractmethod
    def __init__(
        self,
        estimator,
        *,
        scoring=None,
        n_jobs=None,
        refit: bool = True,
        cv=None,
        verbose: int = 0,
        pre_dispatch: str = "2*n_jobs",
        error_score=...,
        return_train_score: bool = True,
    ) -> None: ...
    def score(self, X: list[str] | MatrixLike, y: None | MatrixLike | ArrayLike = None) -> Float: ...
    def score_samples(self, X: Iterable) -> ndarray: ...
    def predict(self, X: ArrayLike) -> ndarray: ...
    def predict_proba(self, X: ArrayLike) -> ndarray: ...
    def predict_log_proba(self, X: ArrayLike) -> ndarray: ...
    def decision_function(self, X: ArrayLike) -> ndarray: ...
    def transform(self, X: ArrayLike) -> ndarray | spmatrix: ...
    def inverse_transform(self, Xt: ArrayLike) -> ndarray | spmatrix: ...
    # Attribute-style accessors: declared as read-only properties, not plain
    # methods. This makes `search.n_features_in_` / `search.classes_` type as
    # values. It also makes the subclasses' `n_features_in_: int` /
    # `classes_: ndarray` declarations valid overrides.
    @property
    def n_features_in_(self) -> int: ...
    @property
    def classes_(self) -> ndarray: ...
    def fit(
        self,
        X: list[str] | MatrixLike,
        y: None | MatrixLike | ArrayLike = None,
        *,
        groups: None | ArrayLike = None,
        **fit_params,
    ) -> Self: ...

class GridSearchCV(BaseSearchCV, Generic[_BaseEstimatorT]):
    """Stub for a search over ``param_grid`` for a given ``estimator``.

    Generic in the estimator type, which also types ``best_estimator_``.
    Note that ``return_train_score`` defaults to ``False`` here, whereas the
    base-class constructor defaults it to ``True``.
    """

    # Instance attributes (values are placeholders in this stub).
    feature_names_in_: ndarray = ...
    n_features_in_: int = ...
    classes_: ndarray = ...
    multimetric_: bool = ...
    refit_time_: float = ...
    n_splits_: int = ...
    scorer_: dict | Callable = ...
    best_index_: int = ...
    best_params_: dict = ...
    best_score_: float = ...
    best_estimator_: _BaseEstimatorT = ...
    cv_results_: dict[str, ndarray] = ...

    _required_parameters: ClassVar[list] = ...

    def __init__(
        self,
        estimator: _BaseEstimatorT,
        param_grid: Sequence[dict] | Mapping,
        *,
        scoring: str | Callable | Mapping | tuple | ArrayLike | None = None,
        n_jobs: Int | None = None,
        refit: bool | str | Callable = True,
        cv: int | Iterable | BaseCrossValidator | None = None,
        verbose: Int = 0,
        pre_dispatch: int | str = "2*n_jobs",
        error_score: Float | str = ...,
        return_train_score: bool = False,
    ) -> None: ...

class RandomizedSearchCV(BaseSearchCV, Generic[_BaseEstimatorT]):
    """Stub for a search over ``param_distributions`` for a given ``estimator``.

    Generic in the estimator type, consistent with ``GridSearchCV``: the type
    of ``estimator`` flows through to ``best_estimator_``. Because
    ``_BaseEstimatorT`` defaults to ``BaseEstimator``, unparameterized uses
    type exactly as before.
    """

    # Instance attributes (values are placeholders in this stub).
    feature_names_in_: ndarray = ...
    n_features_in_: int = ...
    classes_: ndarray = ...
    multimetric_: bool = ...
    refit_time_: float = ...
    n_splits_: int = ...
    scorer_: Callable | dict = ...
    best_index_: int = ...
    best_params_: dict = ...
    best_score_: float = ...
    # Previously fixed to BaseEstimator; now tracks the constructor's estimator type.
    best_estimator_: _BaseEstimatorT = ...
    cv_results_: dict[str, ndarray] = ...

    _required_parameters: ClassVar[list] = ...

    def __init__(
        self,
        estimator: _BaseEstimatorT,
        param_distributions: Sequence[Mapping] | Mapping,
        *,
        n_iter: Int = 10,
        scoring: ArrayLike | None | tuple | Callable | Mapping | str = None,
        n_jobs: None | Int = None,
        refit: str | Callable | bool = True,
        cv: int | BaseCrossValidator | Iterable | None = None,
        verbose: Int = 0,
        pre_dispatch: str | int = "2*n_jobs",
        random_state: RandomState | None | Int = None,
        error_score: str | Float = ...,
        return_train_score: bool = False,
    ) -> None: ...
