#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import abc
import asyncio
import itertools
import operator
import typing
from dataclasses import dataclass, field
from enum import StrEnum
from types import UnionType
from typing import Any, ClassVar, Callable, Annotated, Iterable

from async_tools import acall
from dict_caster.extras import is_iterable
from dynamic_types.class_name import prepare_class_name
from dynamic_types.create import create_type
from http_tools import IncomingRequest
from http_tools.mime_types import ContentType
from http_tools.multipart_form_data import BodyPart
from init_helpers import raise_if, append_if
from init_helpers.dict_to_dataclass import NoValue, convert_to_type, FieldErrors
from multidict import CIMultiDictProxy

from .example import BaseExample


class ParameterLocation(StrEnum):
    """Where a parameter's value is taken from.

    ``query``, ``path``, ``header`` and ``cookie`` are the specification's parameter locations;
    the members marked INNER exist only for the library's internal bookkeeping.
    """
    query = 'query'
    path = 'path'
    header = 'header'
    cookie = 'cookie'
    body = 'body'  # INNER: not present in the specification; request-body parameters (see BodyParameter)
    security = 'security'  # INNER: not present in the specification, used inside the library
    literal = 'literal'  # INNER: not present in the specification, used inside the library


class ParameterStyle(StrEnum):
    """Parameter serialization styles (the ``style`` values of the OpenAPI specification).

    The comment rows under each member show how a parameter named ``color`` is serialized
    for each ``explode`` setting, using these example values:
        string = "blue"
        array  = ["blue", "black", "brown"]
        object = {"R": 100, "G": 200, "B": 150}
    Columns: explode flag | empty value | string | array | object ("n/a" = not supported).
    """

    matrix = 'matrix'
    # explode=False     ;color      ;color=blue     ;color=blue,black,brown                 ;color=R,100,G,200,B,150
    # explode=True      ;color      ;color=blue     ;color=blue;color=black;color=brown     ;R=100;G=200;B=150
    label = 'label'
    # explode=False     .           .blue           .blue.black.brown                       .R.100.G.200.B.150
    # explode=True      .           .blue           .blue.black.brown                       .R=100.G=200.B=150
    form = 'form'
    # explode=False     color=      color=blue      color=blue,black,brown                  color=R,100,G,200,B,150
    # explode=True      color=      color=blue      color=blue&color=black&color=brown      R=100&G=200&B=150
    simple = 'simple'
    # explode=False     n/a         blue            blue,black,brown                        R,100,G,200,B,150
    # explode=True      n/a         blue            blue,black,brown                        R=100,G=200,B=150
    spaceDelimited = 'spaceDelimited'
    # explode=False     n/a         n/a             blue%20black%20brown                    R%20100%20G%20200%20B%20150
    pipeDelimited = 'pipeDelimited'
    # explode=False     n/a         n/a             blue|black|brown                        R|100|G|200|B|150
    deepObject = 'deepObject'
    # explode=True      n/a         n/a             n/a                                     color[R]=100&color[G]=200&color[B]=150


@dataclass(frozen=True, kw_only=True)
class CallParameter(abc.ABC):
    """Source of one argument of a procedure call.

    ``schema`` is the type of the value this parameter produces. Concrete subclasses resolve the
    value for a request (``get``) and report which specification parameters they rely on
    (``get_spec_parameters``).
    """
    schema: type | UnionType

    @abc.abstractmethod
    async def get(self, incoming_request: IncomingRequest, security_kwargs: dict[str, Any]) -> Any:
        """Resolve this argument's value for ``incoming_request``.

        ``security_kwargs`` maps names to values produced by security handling
        (read by SecurityParameter).
        """
        ...

    @abc.abstractmethod
    def get_spec_parameters(self) -> set['SpecParameter']:
        """Return the specification parameters this call parameter depends on."""
        ...


@dataclass(frozen=True, kw_only=True)
class LiteralParameter(CallParameter):
    """Call parameter bound to a fixed ``value``; it contributes nothing to the specification."""
    value: Any

    def get_spec_parameters(self) -> set['SpecParameter']:
        # A constant never appears in the API specification.
        return set()

    async def get(self, incoming_request: IncomingRequest, security_kwargs: dict[str, Any]) -> Any:
        # Neither the request nor the security values matter: the value is fixed at construction.
        constant = self.value
        return constant


@dataclass(frozen=True, kw_only=True)
class SpecParameter(CallParameter, abc.ABC):
    """Параметр, входящий в спецификацию, location - расположение(хедер/query/path/т.п), name - имя в расположении"""
    location: ParameterLocation
    name: str
    is_optional: bool = False
    default: Any = NoValue
    description: str | NoValue = NoValue
    examples: tuple[BaseExample, ...] | NoValue = NoValue
    explode: bool | None = None
    style: ParameterStyle | None = None

    @property
    def annotated_schema(self) -> type | UnionType | Annotated:
        notes = []
        append_if(self.description is not NoValue, notes, self.description)
        append_if(self.examples is not NoValue, notes, *self.examples if self.examples is not NoValue else [])
        return Annotated[self.schema, *notes] if notes else self.schema

    @property
    def is_required(self) -> bool:
        return self.default is NoValue and not self.is_optional

    def get_spec_parameters(self) -> set['SpecParameter']:
        return {self}

    def _cast_value_to_schema(self, value: Any):
        if self.explode:
            type_args = typing.get_args(self.schema)
            if len(type_args) != 1:
                raise TypeError(f'Only generic containers with single argument (list, set, simple tuple) are supported')
            type_arg = type_args[0]
            return [self._cast_single_value(value=val, schema=type_arg, name=self.name) for val in value]
        return self._cast_single_value(value=value, schema=self.schema, name=self.name)

    def _cast_single_value(self, value: Any, schema: type, name: str):
        if self.style is not None and isinstance(value, str):
            value = self._apply_style(value)
        return convert_to_type(value=value, field_type=schema, field_name=name)

    async def get(self, incoming_request: IncomingRequest, security_kwargs: dict[str, Any]) -> Any | NoValue:
        try:
            result = self._get(incoming_request)
        except KeyError:
            if self.default is NoValue:
                if self.is_optional:
                    return NoValue
                raise
            result = self.default
        return self._cast_value_to_schema(result)

    @abc.abstractmethod
    def _get(self, incoming_request: IncomingRequest) -> Any:
        pass

    def _apply_style(self, value: str) -> Any:
        if self.style == ParameterStyle.simple:
            assert isinstance(value, str), "Expected string value, got:"
            return value.split(",")
        return value


@dataclass(frozen=True, kw_only=True)
class QueryParameter(SpecParameter):
    """Specification parameter read from the URL query by ``name``."""
    location: ParameterLocation = ParameterLocation.query

    def _get(self, incoming_request: IncomingRequest) -> Any:
        # A KeyError for an absent name is handled by SpecParameter.get (default / NoValue).
        query = incoming_request.url_query_key_to_value
        return query[self.name]


@dataclass(frozen=True, kw_only=True)
class HeaderParameter(SpecParameter):
    """Specification parameter read from the request headers by ``name``."""
    location: ParameterLocation = ParameterLocation.header

    def _get(self, incoming_request: IncomingRequest) -> Any:
        # A KeyError for an absent header is handled by SpecParameter.get (default / NoValue).
        headers = incoming_request.metadata.header_name_to_value
        return headers[self.name]


@dataclass(frozen=True, kw_only=True)
class PathParameter(SpecParameter):
    """Specification parameter read from the matched URL path by ``name``.

    ``is_optional`` is not accepted by the constructor and is always False.
    """
    is_optional: bool = field(init=False, default=False)
    location: ParameterLocation = ParameterLocation.path

    def _get(self, incoming_request: IncomingRequest) -> Any:
        path_values = incoming_request.path_match_key_to_value
        return path_values[self.name]


@dataclass(frozen=True, kw_only=True)
class BodyParameter(SpecParameter, abc.ABC):
    """Base for specification parameters taken from the request body.

    Concrete subclasses set ``body_mime_type`` to the MIME type(s) of the body they read.
    """
    location: ParameterLocation = ParameterLocation.body
    # None on this abstract base; subclasses assign a ContentType or a tuple of them
    # (annotation widened accordingly — it previously claimed ContentType while defaulting to None).
    body_mime_type: ClassVar[ContentType | tuple[ContentType, ...] | None] = None


@dataclass(frozen=True, kw_only=True)
class JsonWholeBodyParameter(BodyParameter, abc.ABC):
    """Placeholder for a parameter bound to the entire JSON body; not implemented yet.

    Instantiation always fails with NotImplementedError.
    """
    body_mime_type: ClassVar[ContentType] = ContentType.Json

    def __init__(self):
        # Blocked until the class hierarchy is reworked.
        raise NotImplementedError("requires inheritance restructuration")


@dataclass(frozen=True, kw_only=True)
class JsonBodyParameter(BodyParameter):
    """Specification parameter read by ``name`` from the parsed JSON request body."""
    body_mime_type: ClassVar[ContentType] = ContentType.Json

    def _get(self, incoming_request: IncomingRequest) -> Any:
        parsed = incoming_request.parsed_body
        # No body is reported like a missing key so SpecParameter.get can apply default / NoValue.
        raise_if(parsed is None, KeyError(self.name))
        return parsed[self.name]


@dataclass(frozen=True)
class ParsedBodyPart:
    """One multipart body part as produced by MultipartParameter when ``unpack=False``."""
    filename: str | None  # filename of the part, copied from the BodyPart
    header_name_to_value: dict[str, Any]  # the parameter's declared headers, each cast to its own schema
    payload: Any  # parsed part content (raw content when raw=True), cast to the parameter schema
    encoding: str | None  # encoding of the part, copied from the BodyPart
    content_type: ContentType | None  # content type of the part, copied from the BodyPart


class MultipartParameterHeaders:
    """Subscriptable accessor for headers of a multipart parameter's body part.

    ``headers['X-Name']`` yields a CallTemplate that applies ``itemgetter('X-Name')`` to the
    ``header_name_to_value`` attribute of the value resolved by ``parent``.
    """

    def __init__(self, parent: 'MultipartParameter') -> None:
        self.parent = parent
        headers_getter = operator.attrgetter("header_name_to_value")
        self._parent_header_call = CallTemplate(headers_getter, parent)

    def __getitem__(self, item: str):
        is_name = isinstance(item, str)
        raise_if(not is_name, TypeError(f'Header names must be string, got: {type(item)}: {item}'))
        pick_header = operator.itemgetter(item)
        return CallTemplate(pick_header, self._parent_header_call)

    def __repr__(self):
        return f'MultipartParameterHeaders(parent={self.parent})'

    __str__ = __repr__


@dataclass(frozen=True, kw_only=True)
class MultipartParameter(BodyParameter):
    """A named part of a ``multipart/form-data`` request body.

    * ``unpack=True`` (default): the part's payload (parsed, or the raw ``content`` when ``raw=True``)
      is cast to ``schema``.
    * ``unpack=False``: a ParsedBodyPart is produced, carrying filename, encoding, content type, the
      declared ``headers`` (each cast by its HeaderParameter) and the cast payload.
    * ``explode=True``: every part with this name is used and each is cast to the single type argument
      of ``schema`` (which must be ``list``/``tuple``/``set``); otherwise exactly one part is required.
    * ``allowed_content_type``: when given, every matching part's content type must be among them
      (checked in ``_get``).

    The ``payload`` / ``filename`` / ``content_type`` / ``header[...]`` properties build CallTemplates
    that read those attributes of the resolved value, so they presumably need ``unpack=False``.
    """
    schema: type | UnionType
    body_mime_type: ClassVar[ContentType] = ContentType.MultipartForm
    allowed_content_type: ContentType | tuple[ContentType, ...] | None = None
    headers: tuple[HeaderParameter, ...] = field(init=False)
    unpack: bool = True
    explode: bool | None = None
    raw: bool = False

    def __init__(
            self,
            name: str,
            schema: type | UnionType = bytes,
            is_optional: bool = False,
            default: Any = NoValue,
            description: str | NoValue = NoValue,
            examples: tuple[BaseExample, ...] | NoValue = NoValue,
            allowed_content_type: ContentType | Iterable[ContentType] | None = None,
            headers: Iterable[HeaderParameter] | None = None,
            unpack: bool = True,
            explode: bool | None = None,
            is_filename_required: bool = False,
            style: ParameterStyle | None = None,
            raw: bool = False,
    ) -> None:
        """Validate the option combination and assign all fields.

        The dataclass is frozen, hence ``object.__setattr__`` throughout.

        :raises TypeError: ``unpack=False`` with a ``default`` that is not a BodyPart, or
            ``explode=True`` with a schema whose origin is not ``list``/``tuple``/``set``.
        :raises KeyError: several entries of ``headers`` share the same name.
        """
        if not unpack and default is not NoValue:
            raise_if(not isinstance(default, BodyPart),
                     TypeError(f'With unpack=False default MUST be BodyPart, got {default=}'))
        if explode and typing.get_origin(schema) not in (list, tuple, set):
            raise TypeError(f'explode=True can be used only with iterable schema, got {schema=}')
        object.__setattr__(self, 'name', name)
        object.__setattr__(self, 'schema', schema)
        object.__setattr__(self, 'is_optional', is_optional)
        object.__setattr__(self, 'default', default)
        object.__setattr__(self, 'description', description)
        object.__setattr__(self, 'examples', examples)
        object.__setattr__(self, 'raw', raw)

        # The lambdas below are stored as instance attributes and shadow the no-op
        # _check_content_type method. NOTE(review): they raise KeyError, which SpecParameter.get
        # treats like a missing part (falls back to default / NoValue).
        if allowed_content_type:
            if is_iterable(allowed_content_type):
                # noinspection PyUnresolvedReferences
                types_tuple = tuple(t.value for t in allowed_content_type)
                object.__setattr__(self, 'allowed_content_type', types_tuple)
                object.__setattr__(
                    self, '_check_content_type',
                    lambda x: raise_if(x not in types_tuple, KeyError(f'Expected {types_tuple}, got: {x!r}')))
            else:
                object.__setattr__(self, 'allowed_content_type', allowed_content_type)
                object.__setattr__(
                    self, '_check_content_type',
                    lambda x: raise_if(x != allowed_content_type, KeyError(f'Expected {allowed_content_type}, got: {x}')))
        else:
            object.__setattr__(self, 'allowed_content_type', None)

        if headers:
            headers_tuple = tuple(headers)
            object.__setattr__(self, 'headers', headers_tuple)
            header_names = [h.name for h in headers_tuple]
            duplicate_names = sorted({n for n in header_names if header_names.count(n) > 1})
            raise_if(
                bool(duplicate_names),
                KeyError(f'Got header parameters with duplicate names: {duplicate_names}')
            )
        else:
            object.__setattr__(self, 'headers', tuple())
        object.__setattr__(self, 'unpack', unpack)
        object.__setattr__(self, 'explode', explode)
        # NOTE(review): not a declared dataclass field and not read in this class.
        object.__setattr__(self, 'is_filename_required', is_filename_required)
        object.__setattr__(self, 'style', style)

    def _check_content_type(self, content_type: str | None) -> None:
        """No-op unless ``allowed_content_type`` was given (then replaced per instance in ``__init__``)."""
        pass

    def _check_headers(self, header_name_to_value: CIMultiDictProxy[str]) -> None:
        """Hook for validating part headers; currently a no-op."""
        pass

    def _get(self, incoming_request: IncomingRequest) -> BodyPart | Any:
        """Return the body part(s) named ``name``: a list with ``explode``, else the single part.

        :raises KeyError: no parsed body, or no part with this name.
        :raises TypeError: the parsed body is not a multipart body.
        :raises ValueError: without ``explode``, the number of parts is not exactly one.
        """
        raise_if(incoming_request.parsed_body is None, KeyError(self.name))
        raise_if(not isinstance(incoming_request.parsed_body, CIMultiDictProxy),
                 TypeError('Multipart parameter got non multipart body'))
        body: CIMultiDictProxy = incoming_request.parsed_body
        result = body.getall(self.name)
        for body_part in result:
            self._check_content_type(body_part.content_type)
        if not self.explode:
            if len(result) != 1:
                raise ValueError(f"Multipart got {len(result)} values, expected 1")
            result = result[0]

        return result

    def _extract_headers(self, value: BodyPart) -> dict[str, Any]:
        """Collect the declared ``headers`` of one body part, each cast by its HeaderParameter.

        An absent header falls back to its ``default`` (cast like a present value). An absent
        optional header without a default is left out of the result. Missing required headers,
        and headers with an unexpected number of values, are collected and reported together.

        :raises FieldErrors: mapping of header name to the error of every invalid header.
        """
        result: dict[str, Any] = {}
        field_name_to_error: dict[str, Exception] = {}
        for header in self.headers:
            # Assumes get_all_header_values returns its second argument when the header is
            # absent — TODO confirm against http_tools.BodyPart.
            values = value.get_all_header_values(header.name, NoValue)
            if values is NoValue:
                if header.default is not NoValue:
                    result[header.name] = header._cast_value_to_schema(header.default)
                elif not header.is_optional:
                    field_name_to_error[header.name] = KeyError("Missing multipart header")
                continue
            if not header.explode:
                if len(values) != 1:
                    field_name_to_error[header.name] = ValueError(
                        f"Multipart header got {len(values)} values, expected 1")
                    continue
                values = values[0]
            result[header.name] = header._cast_value_to_schema(values)

        if field_name_to_error:
            raise FieldErrors(field_name_to_error)
        return result

    def _read_payload(self, body_part: BodyPart) -> Any:
        """Return the part's raw ``content`` when ``raw`` is set, otherwise its parsed payload."""
        return body_part.content if self.raw else body_part.parse()

    def _cast_single_value(self, value: BodyPart | Any, schema: type, name: str) -> ParsedBodyPart | Any:
        """Cast one body part (or a non-BodyPart default when unpacking).

        :raises TypeError: ``unpack=False`` and ``value`` is not a BodyPart.
        """
        if self.unpack:
            if isinstance(value, BodyPart):
                value = self._read_payload(value)
            return super()._cast_single_value(value=value, schema=schema, name=name)
        if not isinstance(value, BodyPart):
            raise TypeError(f'UNEXPECTED: Non unpack cast got non BodyPart {value=}')
        return ParsedBodyPart(
            filename=value.filename,
            header_name_to_value=self._extract_headers(value),
            payload=super()._cast_single_value(value=self._read_payload(value), schema=schema, name=name),
            encoding=value.encoding,
            content_type=value.content_type,
        )

    def _cast_value_to_schema(self, value: Any):
        """Cast what ``_get`` returned (a BodyPart, a list of them with ``explode``) or the ``default``.

        BodyPart unwrapping is intentionally left to ``_cast_single_value``: unwrapping a lone part
        here with ``parse()`` would bypass ``raw=True``.
        """
        return super()._cast_value_to_schema(value=value)

    @property
    def payload(self) -> 'CallTemplate[bytes]':
        """CallTemplate reading ``.payload`` of the resolved value."""
        return CallTemplate(operator.attrgetter("payload"), self)

    @property
    def filename(self) -> 'CallTemplate[str | None]':
        """CallTemplate reading ``.filename`` of the resolved value."""
        return CallTemplate(operator.attrgetter("filename"), self)

    @property
    def content_type(self) -> 'CallTemplate[ContentType | None]':
        """CallTemplate reading ``.content_type`` of the resolved value."""
        return CallTemplate(operator.attrgetter("content_type"), self)

    @property
    def header(self) -> MultipartParameterHeaders:
        """Subscriptable accessor: ``param.header['X-Name']`` reads that header of the resolved value."""
        return MultipartParameterHeaders(self)


@dataclass(frozen=True, kw_only=True)
class RawBodyParameter(BodyParameter):
    """The whole request payload, with ``schema`` fixed to ``bytes`` and no ``name``.

    ``RawBodyParameter[content_type]`` creates a subclass whose ``body_mime_type`` is the given
    ContentType (or a tuple, when an iterable of them is given).
    """
    body_mime_type: ClassVar[ContentType | tuple[ContentType, ...]] = ContentType.Octet
    # The raw body is not addressed by name; annotation widened to admit the None default.
    name: str | None = field(init=False, default=None)
    schema: type = field(init=False, default=bytes)

    def _get(self, incoming_request: IncomingRequest) -> Any:
        return incoming_request.payload

    def __class_getitem__(cls, content_type: ContentType | Iterable[ContentType]) -> type:
        """Return a subclass of ``cls`` with ``body_mime_type`` set to ``content_type``.

        Iterables are frozen into a tuple first.
        """
        mime_types = tuple(content_type) if is_iterable(content_type) else content_type
        return create_type(prepare_class_name([cls], mime_types), [cls], {'body_mime_type': mime_types})


class ParameterAggregation(CallParameter, abc.ABC):
    """Base for call parameters built out of other call parameters (e.g. CallTemplate)."""

    @abc.abstractmethod
    def get_spec_parameters(self) -> set[SpecParameter]:
        """Return the specification parameters of all aggregated parameters."""
        ...


@dataclass(frozen=True, kw_only=True)
class CallTemplate(ParameterAggregation):
    """Deferred call of ``func`` whose arguments are themselves call parameters.

    On ``get`` every argument is resolved for the request, arguments that resolved to ``NoValue``
    are dropped, and ``func`` is called with the rest (its result is passed through ``acall``).

    NOTE(review): the inherited required field ``schema`` is never set by ``__init__``, so reading
    ``self.schema`` (including via the dataclass-generated ``__repr__``) raises AttributeError.
    """
    func: Callable
    args: tuple[CallParameter, ...]
    kwargs: dict[str, CallParameter]

    def __init__(self, func: Callable, *args: CallParameter, **kwargs: CallParameter):
        object.__setattr__(self, 'func', func)
        object.__setattr__(self, 'args', args)
        object.__setattr__(self, 'kwargs', kwargs)

    async def get(self, incoming_request: IncomingRequest, security_kwargs: dict[str, Any]) -> Any:
        arg_values = await asyncio.gather(*[arg.get(incoming_request, security_kwargs) for arg in self.args])
        kwarg_values = await asyncio.gather(
            *[param.get(incoming_request, security_kwargs) for param in self.kwargs.values()])
        # Dropping a NoValue positional argument shifts the following positional arguments left.
        call_args = [value for value in arg_values if value is not NoValue]
        # Keys come from self.kwargs, values from the gathered results (same iteration order).
        call_kwargs = {key: value for key, value in zip(self.kwargs, kwarg_values) if value is not NoValue}
        return await acall(self.func(*call_args, **call_kwargs))

    def get_spec_parameters(self) -> set[SpecParameter]:
        """Union of the specification parameters of all positional and keyword arguments."""
        result = set()
        for arg in itertools.chain(self.args, self.kwargs.values()):
            result = result.union(arg.get_spec_parameters())
        return result


@dataclass(frozen=True, kw_only=True)
class SecurityParameter(CallParameter):
    """Call parameter whose value is produced by a Security object.

    The value is looked up by ``name`` in ``security_kwargs`` and converted to ``schema``.
    It does not appear in the specification.
    """
    name: str

    async def get(self, incoming_request: IncomingRequest, security_kwargs: dict[str, Any]) -> Any:
        """:raises KeyError: ``security_kwargs`` has no entry for ``name``."""
        value = security_kwargs[self.name]
        # Keyword arguments, matching every other convert_to_type call in this module, so the
        # value / type / name cannot be bound to the wrong parameters.
        return convert_to_type(value=value, field_type=self.schema, field_name=self.name)

    def get_spec_parameters(self) -> set[SpecParameter]:
        return set()
