#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import abc
import asyncio
import itertools
from dataclasses import dataclass, field
from enum import StrEnum
from types import UnionType
from typing import Any, ClassVar, Callable

from async_tools import acall
from dynamic_types.class_name import _prepare_class_name
from dynamic_types.create import create_type
from http_tools import IncomingRequest
from http_tools.mime_types import ContentType
from init_helpers.dict_to_dataclass import NoValue, convert_to_type

from openapi_tools.extras import raise_if


class ParameterLocation(StrEnum):
    """Where the value of a call parameter comes from.

    ``query``, ``path``, ``header`` and ``cookie`` are locations that appear in the
    API specification. The members marked INNER below are used only inside this
    library and are never emitted into the specification.
    """
    query = 'query'
    path = 'path'
    header = 'header'
    cookie = 'cookie'
    body = 'body'  # INNER: not present in specification, used inside library
    security = 'security'  # INNER: not present in specification, used inside library
    literal = 'literal'  # INNER: not present in specification, used inside library


@dataclass(frozen=True, kw_only=True)
class CallParameter(abc.ABC):
    """Abstract argument of a procedure call.

    ``schema`` is the type the resolved value must be converted to. Concrete
    subclasses decide where the value comes from (request, constant, security
    data, another call, ...).
    """
    schema: type | UnionType

    @abc.abstractmethod
    async def get(self, incoming_request: IncomingRequest, security_kwargs: dict[str, Any]) -> Any:
        """Resolve this parameter's value for ``incoming_request``.

        ``security_kwargs`` carries values produced by security checks; subclasses
        that do not need it simply ignore it.
        """

    @abc.abstractmethod
    def get_spec_parameters(self) -> set['SpecParameter']:
        """Return the specification parameters this parameter depends on."""


@dataclass(frozen=True, kw_only=True)
class LiteralParameter(CallParameter):
    """Call parameter that always resolves to the fixed ``value`` given at construction."""
    value: Any

    def get_spec_parameters(self) -> set['SpecParameter']:
        # A constant is not read from the request, so it contributes nothing to the spec.
        return set()

    async def get(self, incoming_request: IncomingRequest, security_kwargs: dict[str, Any]) -> Any:
        # Both the request and the security data are intentionally ignored.
        return self.value


@dataclass(frozen=True, kw_only=True)
class SpecParameter(CallParameter, abc.ABC):
    """Call parameter that is declared in the API specification.

    ``location`` tells where in the request the value lives (header, query, path,
    ...), and ``name`` is the key it is stored under in that location.

    When the value is missing, ``default`` is used instead. If no default was
    supplied (it is ``NoValue``), the ``KeyError`` propagates to the caller. Both
    the extracted value and the default are converted to ``schema``.
    """
    location: ParameterLocation
    name: str
    is_optional: bool = False
    default: Any = NoValue

    @property
    def is_required(self) -> bool:
        """True unless the parameter was marked optional."""
        return not self.is_optional

    def get_spec_parameters(self) -> set['SpecParameter']:
        # A spec parameter is its own and only spec dependency.
        return set((self,))

    def _cast_value_to_schema(self, value: Any):
        # self.name is forwarded to convert_to_type, presumably for error reporting.
        return convert_to_type(self.schema, self.name, value)

    @abc.abstractmethod
    def _get(self, incoming_request: IncomingRequest) -> Any:
        """Extract the raw value from the request; raise ``KeyError`` when it is absent."""

    async def get(self, incoming_request: IncomingRequest, security_kwargs: dict[str, Any]) -> Any:
        try:
            raw_value = self._get(incoming_request)
        except KeyError:
            # Missing value: fall back to the default, or re-raise when there is none.
            if self.default is NoValue:
                raise
            raw_value = self.default
        return self._cast_value_to_schema(raw_value)


@dataclass(frozen=True, kw_only=True)
class QueryParameter(SpecParameter):
    """Spec parameter read from the URL query string under ``name``."""
    location: ParameterLocation = ParameterLocation.query

    def _get(self, incoming_request: IncomingRequest) -> Any:
        # A missing key surfaces as KeyError, which SpecParameter.get turns into the default.
        query_values = incoming_request.url_query_key_to_value
        return query_values[self.name]


@dataclass(frozen=True, kw_only=True)
class HeaderParameter(SpecParameter):
    """Spec parameter read from the request headers under ``name``."""
    location: ParameterLocation = ParameterLocation.header

    def _get(self, incoming_request: IncomingRequest) -> Any:
        # NOTE(review): header-name case handling depends on the mapping type used by
        # http_tools — not visible here.
        headers = incoming_request.metadata.header_name_to_value
        return headers[self.name]


@dataclass(frozen=True, kw_only=True)
class PathParameter(SpecParameter):
    """Spec parameter taken from the matched URL path segment ``name``.

    Path parameters are always required: ``is_optional`` is pinned to ``False``
    and cannot be passed to the constructor.
    """
    is_optional: bool = field(init=False, default=False)
    location: ParameterLocation = ParameterLocation.path

    def _get(self, incoming_request: IncomingRequest) -> Any:
        path_values = incoming_request.path_match_key_to_value
        return path_values[self.name]


@dataclass(frozen=True, kw_only=True)
class BodyParameter(SpecParameter, abc.ABC):
    """Abstract base for spec parameters read from the request body.

    ``body_mime_type`` is a class-level attribute (not a dataclass field) that
    concrete subclasses set to the content type of the body they read.
    """
    location: ParameterLocation = ParameterLocation.body
    # None on the abstract base: it is not bound to any content type. The annotation
    # admits None so that it matches the actual default.
    body_mime_type: ClassVar[ContentType | None] = None


@dataclass(frozen=True, kw_only=True)
class JsonBodyParameter(BodyParameter):
    """Body parameter read under ``name`` from the request's parsed body.

    The class declares ``ContentType.Json`` as its body content type.
    """
    body_mime_type: ClassVar[ContentType] = ContentType.Json

    def _get(self, incoming_request: IncomingRequest) -> Any:
        # A request without a parsed body is reported (via raise_if) as a missing key,
        # so SpecParameter.get applies the same default/re-raise logic as for an absent field.
        body_absent = incoming_request.parsed_body is None
        raise_if(body_absent, KeyError(self.name))
        return incoming_request.parsed_body[self.name]


@dataclass(frozen=True, kw_only=True)
class RawBodyParameter(BodyParameter):
    """Body parameter whose value is the raw request payload.

    ``name`` and ``schema`` are fixed (``None`` and ``bytes``) and cannot be passed
    to the constructor. The default body content type is ``ContentType.Octet``.
    Subscripting the class with a content type builds a dynamic subclass whose
    ``body_mime_type`` is that content type.
    """
    body_mime_type: ClassVar[ContentType] = ContentType.Octet
    name: str = field(init=False, default=None)
    schema: type = field(init=False, default=bytes)

    def __class_getitem__(cls, content_type: ContentType) -> type:
        # Each subscription creates a new subclass; the name comes from dynamic_types.
        subclass_name = _prepare_class_name([cls], content_type)
        class_namespace = {'body_mime_type': content_type}
        return create_type(subclass_name, [cls], class_namespace)

    def _get(self, incoming_request: IncomingRequest) -> Any:
        # The whole payload is the value; no key lookup is involved.
        return incoming_request.payload

# @dataclass(frozen=True, kw_only=True)
# class SecurityParameter(CallParameter):
#     name: str
#     security: Security
#
#     async def get(self, incoming_request: IncomingRequest) -> Any:
#         return self.security.validate(incoming_request, [self.name])


class ParameterAggregation(CallParameter, abc.ABC):
    """Abstract call parameter built out of other call parameters.

    Implementations are expected to report the spec parameters of everything they
    aggregate.
    """

    @abc.abstractmethod
    def get_spec_parameters(self) -> set[SpecParameter]:
        """Return the spec parameters required by all aggregated parameters."""


@dataclass(frozen=True, kw_only=True)
class CallTemplate(ParameterAggregation):
    """Deferred call of ``func`` whose arguments are themselves call parameters.

    ``CallTemplate(func, p1, p2, key=p3)`` resolves ``p1``, ``p2`` and ``p3``
    against the incoming request on ``get`` and then calls
    ``func(v1, v2, key=v3)``. The call result is passed through ``acall``.

    NOTE(review): the dataclass-generated ``__hash__`` hashes ``kwargs``, which is
    a dict, so hashing a CallTemplate raises TypeError.
    """
    func: Callable
    args: tuple[CallParameter, ...]
    kwargs: dict[str, CallParameter]

    def __init__(self, func: Callable, *args: CallParameter, **kwargs: CallParameter):
        # A custom __init__ keeps the natural call-like signature. The dataclass is
        # frozen, so the fields have to be set through object.__setattr__.
        object.__setattr__(self, 'func', func)
        object.__setattr__(self, 'args', args)
        object.__setattr__(self, 'kwargs', kwargs)

    async def get(self, incoming_request: IncomingRequest, security_kwargs: dict[str, Any]) -> Any:
        """Resolve all argument parameters, then invoke ``func`` with the resolved values."""
        args = await asyncio.gather(*[arg.get(incoming_request, security_kwargs) for arg in self.args])
        kwarg_values = await asyncio.gather(
            *[param.get(incoming_request, security_kwargs) for param in self.kwargs.values()]
        )
        # Keys first, values second. Iterating self.kwargs yields keys in the same
        # order as .values(), and asyncio.gather preserves that order.
        kwargs = dict(zip(self.kwargs, kwarg_values))
        # NOTE(review): acall presumably awaits the result when it is awaitable — confirm in async_tools.
        return await acall(self.func(*args, **kwargs))

    def get_spec_parameters(self) -> set[SpecParameter]:
        """Return the union of the spec parameters of every positional and keyword argument."""
        result: set[SpecParameter] = set()
        for arg in itertools.chain(self.args, self.kwargs.values()):
            # update() mutates in place; union() would return a new set and drop it.
            result.update(arg.get_spec_parameters())
        return result


@dataclass(frozen=True, kw_only=True)
class SecurityParameter(CallParameter):
    """Call parameter whose value is produced by a Security object.

    The value is looked up under ``name`` in ``security_kwargs`` and converted to
    ``schema``. It is not declared in the specification itself.
    """
    name: str

    def get_spec_parameters(self) -> set[SpecParameter]:
        return set()

    async def get(self, incoming_request: IncomingRequest, security_kwargs: dict[str, Any]) -> Any:
        # A missing key raises KeyError from the lookup; there is no default here.
        return convert_to_type(self.schema, self.name, security_kwargs[self.name])
