from abc import ABCMeta, abstractmethod
from collections.abc import Sequence
from numbers import Integral as Integral
from typing import ClassVar
from typing_extensions import Self

import numpy as np
import scipy.sparse as sp
from numpy import ndarray
from scipy.sparse import spmatrix

from ._typing import ArrayLike, MatrixLike
from .base import (
    BaseEstimator,
    ClassifierMixin,
    MetaEstimatorMixin,
    RegressorMixin,
    clone as clone,
    is_classifier as is_classifier,
)
from .ensemble._forest import RandomForestRegressor
from .linear_model._logistic import LogisticRegression
from .model_selection import cross_val_predict as cross_val_predict
from .utils import check_random_state as check_random_state
from .utils._param_validation import HasMethods as HasMethods, StrOptions as StrOptions
from .utils.metaestimators import available_if as available_if
from .utils.multiclass import check_classification_targets as check_classification_targets
from .utils.parallel import Parallel as Parallel, delayed as delayed
from .utils.validation import check_is_fitted as check_is_fitted, has_fit_parameter as has_fit_parameter

__all__ = [
    "MultiOutputRegressor",
    "MultiOutputClassifier",
    "ClassifierChain",
    "RegressorChain",
]

class _MultiOutputEstimator(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
    _parameter_constraints: ClassVar[dict] = ...

    @abstractmethod
    def __init__(self, estimator: RandomForestRegressor, *, n_jobs=None) -> None: ...
    def partial_fit(
        self,
        X: MatrixLike | ArrayLike,
        y: MatrixLike,
        classes: Sequence[ArrayLike] | None = None,
        sample_weight: None | ArrayLike = None,
    ) -> Self: ...
    def fit(
        self,
        X: MatrixLike | ArrayLike,
        y: MatrixLike,
        sample_weight: None | ArrayLike = None,
        **fit_params,
    ) -> Self | MultiOutputRegressor: ...
    def predict(self, X: MatrixLike | ArrayLike) -> ndarray | spmatrix: ...

class MultiOutputRegressor(RegressorMixin, _MultiOutputEstimator):
    feature_names_in_: ndarray = ...
    n_features_in_: int = ...
    estimators_: list[BaseEstimator] = ...

    def __init__(self, estimator: BaseEstimator | RandomForestRegressor, *, n_jobs: None | int = None) -> None: ...
    def partial_fit(
        self,
        X: MatrixLike | ArrayLike,
        y: MatrixLike,
        sample_weight: None | ArrayLike = None,
    ) -> Self: ...

class MultiOutputClassifier(ClassifierMixin, _MultiOutputEstimator):
    feature_names_in_: ndarray = ...
    n_features_in_: int = ...
    estimators_: list[BaseEstimator] = ...
    classes_: ndarray = ...

    def __init__(self, estimator: BaseEstimator, *, n_jobs: None | int = None) -> None: ...
    def fit(
        self,
        X: MatrixLike | ArrayLike,
        Y: MatrixLike,
        sample_weight: None | ArrayLike = None,
        **fit_params,
    ) -> Self: ...
    def predict_proba(self, X: MatrixLike) -> ndarray | list[ndarray]: ...
    def score(self, X: MatrixLike, y: MatrixLike) -> float: ...

class _BaseChain(BaseEstimator, metaclass=ABCMeta):
    _parameter_constraints: ClassVar[dict] = ...

    def __init__(
        self, base_estimator: LogisticRegression, *, order=None, cv=None, random_state=None, verbose: bool = False
    ) -> None: ...
    @abstractmethod
    def fit(self, X: MatrixLike | ArrayLike, Y: MatrixLike, **fit_params) -> Self | ClassifierChain: ...
    def predict(self, X: MatrixLike | ArrayLike) -> ndarray: ...

class ClassifierChain(MetaEstimatorMixin, ClassifierMixin, _BaseChain):
    feature_names_in_: ndarray = ...
    n_features_in_: int = ...
    order_: list = ...
    estimators_: list = ...
    classes_: list = ...

    def fit(self, X: MatrixLike | ArrayLike, Y: MatrixLike) -> Self: ...
    def predict_proba(self, X: MatrixLike | ArrayLike) -> ndarray: ...
    def decision_function(self, X: MatrixLike) -> ndarray: ...

class RegressorChain(MetaEstimatorMixin, RegressorMixin, _BaseChain):
    feature_names_in_: ndarray = ...
    n_features_in_: int = ...
    order_: list = ...
    estimators_: list = ...

    def fit(self, X: MatrixLike | ArrayLike, Y: MatrixLike, **fit_params) -> Self: ...
