import collections.abc
import functools
import inspect
import types
import typing
from typing import TypeVar, Callable, Any, Dict, Union, Collection

from .exception import ComplexItemException
from .extras import first

# Generic type variable; used in the signature of BaseCaster._cast_if_not_is_instance.
T = TypeVar('T')
# Constrained type variable used in the annotations of BaseCaster.cast: a target may be
# a class, a mapping, a collection, or a one-argument callable.
CastType = TypeVar('CastType', type, typing.Mapping, typing.Collection, Callable[[Any], Any])


def no_cast(x):
    """Identity converter: hand back ``x`` exactly as received (same object, no copy)."""
    unchanged = x
    return unchanged


# TODO: rename to RecursiveCaster
class BaseCaster:
    """Recursively convert values to a target described by a type hint or an example value.

    Accepted targets:

    * a plain class (``int``, ``dict``, ``list``, ...);
    * a parametrized generic (``dict[str, int]``, ``typing.List[bool]``,
      ``typing.Mapping[str, typing.Any]``, ...). Abstract ``collections.abc``
      origins are built as ``dict`` (mappings), ``list`` or ``set`` (collections);
    * a union (``typing.Union[int, str]``, ``typing.Optional[int]`` and, on
      Python 3.10+, ``int | str``): the first member that converts wins;
    * an example container whose single element describes the inner type
      (``[int]``, ``{str: [int]}``, ``set()``, ...);
    * ``typing.Any``: values are passed through unchanged;
    * any other callable, which is used as the converter directly.

    Converters registered with :meth:`set_cast_for_key` are consulted only for
    targets that are not unions, mappings or non-``str`` collections.
    """

    def __init__(self):
        # Converters registered through set_cast_for_key, keyed by target.
        self._key_to_cast: Dict[Any, Callable[[Any], Any]] = {}

    def set_cast_for_key(self, key: Any, cast: typing.Optional[Callable[[Any], Any]]):
        """Register ``cast`` as the converter for target ``key``, or unregister it.

        :param key: hashable target the converter applies to.
        :param cast: one-argument converter; ``None`` removes the registration for ``key``.
        :raises KeyError: if ``cast`` is ``None`` and nothing is registered for ``key``.
        """
        if cast is None:
            del self._key_to_cast[key]
        else:
            self._key_to_cast[key] = cast

    @staticmethod
    def _get_inner_type(target) -> typing.Optional[Union[tuple, list]]:
        """Return the inner type(s) described by ``target``.

        * Example container: ``[]`` when empty, the single ``(key, value)`` item of a
          one-item mapping, or ``[element]`` for a one-element collection.
        * Generic alias: the tuple of its ``__args__`` with TypeVar arguments dropped.
        * Anything else (or a generic with only TypeVar arguments): ``None``.

        :raises KeyError: if an example container has more than one element.
        """
        if isinstance(target, collections.abc.Collection):
            if not target:
                return []
            if len(target) > 1:
                raise KeyError('Expected at most one inner type')
            if isinstance(target, collections.abc.Mapping):
                return first(target.items(), none_if_empty=True)
            return [first(target)]
        type_args = tuple(arg for arg in getattr(target, '__args__', ()) if not isinstance(arg, typing.TypeVar))
        return type_args or None

    @staticmethod
    def _is_union(target_type: Any) -> bool:
        """Return True for ``typing.Union[...]`` / ``Optional[...]`` and PEP 604 ``X | Y`` unions."""
        if getattr(target_type, '__origin__', None) is typing.Union:
            return True
        union_type = getattr(types, 'UnionType', None)  # only exists on Python 3.10+
        return union_type is not None and isinstance(target_type, union_type)

    @staticmethod
    def _concrete_type(outer_type: type, candidates: typing.Tuple[type, ...]) -> type:
        """Return ``outer_type``, or a concrete stand-in when it is an abstract class.

        Targets such as ``typing.Mapping[str, int]`` or ``typing.Sequence[int]`` have an
        abstract ``collections.abc`` class as origin, which cannot be instantiated. The
        first of ``candidates`` that is a subclass of it is used instead; abstract
        classes without a matching candidate are returned unchanged.
        """
        if not inspect.isabstract(outer_type):
            return outer_type
        return next((concrete for concrete in candidates if issubclass(concrete, outer_type)), outer_type)

    @typing.overload
    def cast(self, value: Any, target_type: typing.Type[CastType]) -> CastType: pass
    @typing.overload
    def cast(self, value: Any, target_type: CastType) -> CastType: pass

    def cast(self, value: Any, target_type: Union[CastType, typing.Type[CastType]]) -> CastType:
        """Convert ``value`` to ``target_type`` (see the class docstring for accepted targets).

        :raises KeyError: if no converter can be built for ``target_type``.
        :raises ValueError: if a mapping/collection target receives an incompatible value;
            other errors (e.g. TypeError) are whatever the underlying converter raises.
        :raises ComplexItemException: if no member of a union target accepts ``value``.
        """
        converter = self._get_converter(target_type)
        return converter(value)

    def _get_converter(self, target_type: Any) -> Any:
        """Build (recursively) a one-argument converter for ``target_type``.

        Resolution order: unions, mappings, non-``str`` collections, converters
        registered via ``set_cast_for_key``, ``typing.Any``, then classes (via
        ``_cast_if_not_is_instance``) and finally other callables used as-is.

        :raises KeyError: if no converter can be determined, or an example container
            target has more than one element.
        """
        if self._is_union(target_type):
            # Build every member converter up front; at call time the first one that
            # succeeds wins.
            inner_types = self._get_inner_type(target_type)
            converters = [self._get_converter(type_) for type_ in inner_types]
            return functools.partial(self._return_first_success, converters=converters)

        # __origin__ is set on parametrized generics, e.g. dict[int, str] -> dict and
        # typing.Mapping[str, int] -> collections.abc.Mapping.
        origin_type = getattr(target_type, '__origin__', None)
        if origin_type:  # dict[int, str], typing.Dict[int, str], list[bool], ...
            outer_type = origin_type
        elif isinstance(target_type, type):
            outer_type = target_type  # dict, list, bool, ...
        else:
            outer_type = type(target_type)  # {}, [], set(), ...

        if issubclass(outer_type, collections.abc.Mapping):
            key_converter = value_converter = no_cast
            if inner_type := self._get_inner_type(target_type):
                key_converter = self._get_converter(inner_type[0])
                # Single-parameter mappings (e.g. typing.Counter[str]) describe only
                # the keys; their values are left unconverted.
                if len(inner_type) > 1:
                    value_converter = self._get_converter(inner_type[1])
            return functools.partial(
                self._to_mapping,
                outer_type=self._concrete_type(outer_type, (dict,)),
                key_converter=key_converter,
                value_converter=value_converter,
            )

        # str is a Collection of characters but is treated as a scalar target.
        if not issubclass(outer_type, str) and issubclass(outer_type, collections.abc.Collection):
            inner_converter = no_cast
            if inner_type := self._get_inner_type(target_type):
                inner_converter = self._get_converter(inner_type[0])
            return functools.partial(
                self._to_collection,
                outer_type=self._concrete_type(outer_type, (list, set)),
                inner_converter=inner_converter,
            )

        converter = self._key_to_cast.get(target_type)
        if converter is not None:
            return converter

        if target_type is typing.Any:
            # typing.Any can neither be instantiated nor used with isinstance(),
            # so every value is accepted as-is.
            return no_cast

        if callable(target_type):
            if isinstance(target_type, type):
                return functools.partial(self._cast_if_not_is_instance, target_type=target_type)
            return target_type

        raise KeyError(f'no converter registered for target_type: {target_type}, and target_type is not callable')

    @staticmethod
    def _to_collection(value: Any, outer_type: type, inner_converter: Callable) -> Collection:
        """Build ``outer_type`` from ``value``'s elements, each passed through ``inner_converter``.

        Note that a ``str`` value is iterable and is therefore split into characters.

        :raises ValueError: if ``value`` is a mapping or is not iterable.
        """
        if isinstance(value, collections.abc.Mapping) or not isinstance(value, collections.abc.Iterable):
            raise ValueError(f'expected iterable, got: {type(value)}{value!r:.100}')
        return outer_type(inner_converter(val) for val in value)

    @staticmethod
    def _to_mapping(value: Any, outer_type: type, key_converter: Callable,
                    value_converter: Callable) -> typing.Mapping:
        """Build an ``outer_type`` mapping with converted keys and values from ``value``.

        :raises ValueError: if ``value`` is not a mapping.
        """
        if not isinstance(value, collections.abc.Mapping):
            raise ValueError(f'expected Mapping, got: {type(value)}{value!r:.100}')
        result = outer_type()
        for key, val in value.items():
            result[key_converter(key)] = value_converter(val)
        return result

    @staticmethod
    def _return_first_success(value: Any, converters: typing.Iterable[Callable[[Any], Any]]) -> Any:
        """Return the result of the first converter that accepts ``value``.

        :raises ComplexItemException: carrying the reprs of every failure when all
            converters raise ValueError, TypeError or ComplexItemException.
        """
        exceptions = []
        for converter in converters:
            try:
                return converter(value)
            except (ValueError, TypeError, ComplexItemException) as e:
                exceptions.append(repr(e))

        raise ComplexItemException(exceptions)

    @staticmethod
    def _cast_if_not_is_instance(value: Any, target_type: type[T]) -> T:
        """Return ``value`` unchanged if it is already a ``target_type`` (subclasses included),
        otherwise ``target_type(value)``."""
        if isinstance(value, target_type):
            return value
        return target_type(value)
