import abc
import dataclasses
import itertools
import logging
import traceback
from abc import ABC
from dataclasses import dataclass
from typing import Union, Callable, Any, Type

from async_tools import acall
from dict_caster import DictCaster, Item
from dict_caster.extras import first
from http_tools import Answer, HttpServer
from http_tools.request import IncomingRequest
from init_helpers.dict_to_dataclass import dict_to_dataclass, NoValue

from .open_api_wrapper import OpenApiWrapper, Parameter, Endpoint, Security, RequestBody, Content, ParameterLocation

# Module-level logger; the request handler uses it to record exceptions mapped to the
# unhandled-exception answer.
logger = logging.getLogger(name=__name__)


@dataclass(frozen=True)
class JsonBodyParameter(Parameter):
    """Parameter located in a JSON request body.

    Body parameters of a non-raw type are combined by ``OpenApiServer.register_handler`` into one
    generated dataclass that both documents the body and parses ``request.parsed_body``.
    """
    in_: ParameterLocation = ParameterLocation.body
    # Annotated so dataclasses treats this as a field override (as RawBodyParameter does).
    # NOTE(review): without the annotation, and assuming Parameter declares body_mime_type as a
    # dataclass field, the generated __init__ would store the base default on the instance and
    # shadow this class attribute, so the parameter would not be recognised as a body parameter.
    body_mime_type: str = "application/json"


@dataclass(frozen=True)
class RawBodyParameter(Parameter):
    """Parameter bound to the unparsed request payload.

    ``OpenApiServer.register_handler`` documents endpoints using it as a ``bytes`` request body
    with ``body_mime_type`` as the content type, and fills it from ``request.payload``.
    """
    # Content type advertised for the raw body in the OpenAPI description.
    body_mime_type: str = "application/octet-stream"
    # Key under which the payload is looked up when building call arguments; None by default.
    name: str | None = None
    schema: Type = bytes
    required: bool = True
    in_: ParameterLocation = ParameterLocation.body


class ParameterAggregation(ABC):
    """Interface for a call argument whose value is derived from request parameters.

    Implementations report which request parameters they need and compute their value from the
    already-cast request values.
    """

    @abc.abstractmethod
    def get_parameters(self) -> set[Parameter]:
        """Return the request parameters this aggregation reads."""

    @abc.abstractmethod
    async def execute(self, context: dict) -> Any:
        """Compute the aggregated value from ``context`` (request values keyed by parameter name)."""


@dataclass
class RpcEndpoint:
    """Declarative description of one RPC endpoint: what to call and how to feed and answer it."""
    # Function invoked with the prepared keyword arguments.
    callable: Callable
    securities: list[Security]
    # Keyword argument name -> a single request Parameter or a ParameterAggregation over several.
    argument_name_to_parameter: dict[str, Union[Parameter, ParameterAggregation]]
    # Answer type wrapping the callable's return value on success.
    answer_type: Type[Answer]
    # Exception type -> answer type used when that exception is raised.
    exception_type_to_answer_type: dict[Type[Exception], Type[Answer]]
    # Answer type for exceptions not matched above; if None such exceptions propagate.
    unhandled_exception__answer_type: Type[Answer] | None = None
    # Explicit operation id; falls back to the callable's __name__ when None.
    operation_id: str | None = None

    def get_parameters(self) -> set[Parameter]:
        """Return the flat set of request ``Parameter``s needed to build the call arguments.

        :raises TypeError: if an entry is neither a ``Parameter`` nor a ``ParameterAggregation``.
        """
        collected = set()
        for entry in self.argument_name_to_parameter.values():
            if isinstance(entry, ParameterAggregation):
                collected.update(entry.get_parameters())
                continue
            if not isinstance(entry, Parameter):
                raise TypeError(f"Unexpected parameter type: {type(entry)}")
            collected.add(entry)
        return collected

    async def prepare_parameters(self, parameter_name_to_value: dict[str, Any]) -> dict[str, Any]:
        """Build keyword arguments for ``callable`` from request values keyed by parameter name.

        :raises KeyError: if a plain ``Parameter``'s name is absent from ``parameter_name_to_value``.
        :raises TypeError: if an entry is neither a ``Parameter`` nor a ``ParameterAggregation``.
        """
        kwargs = {}
        for argument_name, source in self.argument_name_to_parameter.items():
            if isinstance(source, ParameterAggregation):
                kwargs[argument_name] = await source.execute(parameter_name_to_value)
                continue
            if not isinstance(source, Parameter):
                raise TypeError(f"Unexpected parameter type: {type(source)}")
            kwargs[argument_name] = parameter_name_to_value[source.name]
        return kwargs


class OpenApiServer:
    """Registers RPC endpoints in the OpenAPI description and as handlers on an HTTP server.

    Each ``RpcEndpoint`` is described through ``OpenApiWrapper`` and turned into an async request
    handler registered on ``context.http_server``.
    """

    @dataclass
    class Context(OpenApiWrapper.Context):
        # HTTP server on which request handlers are registered.
        http_server: HttpServer

    @dataclass
    class Config(OpenApiWrapper.Config):
        pass

    def __init__(self, config: Config, context: Context):
        self.context = context
        self.config = config
        self.wrapper = OpenApiWrapper(config, context)

    @staticmethod
    def _get_raw_body_mime_type(body_parameters: list[Parameter]) -> str:
        """Return the mime type shared by all raw body parameters.

        :raises TypeError: if two raw body parameters declare different mime types.
        """
        mime_type = None
        for param in body_parameters:
            if mime_type is not None and mime_type != param.body_mime_type:
                raise TypeError(
                    f"Multiple raw body parameters have different mime types: {mime_type}, "
                    f"{param.body_mime_type}"
                )
            mime_type = param.body_mime_type
        return mime_type

    @staticmethod
    def _make_body_dataclass(class_name: str, body_parameters: list[Parameter]) -> type:
        """Generate a dataclass with one field per body parameter (used to document and parse the body)."""
        # Fields without a default must precede those with one, as dataclasses require.
        sorted_body_parameters = sorted(body_parameters, key=lambda x: x.default is not NoValue)
        # TODO: multiple endpoints with different body signature will break here
        dataclass_dict = {
            param.name: dataclasses.field(
                default=param.default if param.default is not NoValue else dataclasses.MISSING
            )
            for param in sorted_body_parameters
        }
        class_ = type(class_name, tuple(), dataclass_dict)
        class_.__annotations__ = {param.name: param.schema for param in sorted_body_parameters}
        # noinspection PyTypeChecker
        return dataclasses.dataclass(class_)

    @staticmethod
    def _build_code_to_answer(rpc_endpoint: RpcEndpoint) -> dict[Any, Type[Answer]]:
        """Map each status code the endpoint can produce to the Answer type producing it.

        On a code collision later entries win: exception answers override the success answer and
        the unhandled-exception answer overrides both.
        """
        code_to_answer = {rpc_endpoint.answer_type.get_class_status_code(): rpc_endpoint.answer_type}
        for answer_type in rpc_endpoint.exception_type_to_answer_type.values():
            code_to_answer[answer_type.get_class_status_code()] = answer_type
        unhandled_answer_type = rpc_endpoint.unhandled_exception__answer_type
        if unhandled_answer_type is not None:
            code_to_answer[unhandled_answer_type.get_class_status_code()] = unhandled_answer_type
        return code_to_answer

    @staticmethod
    def _cast_request_parameters(
            request: IncomingRequest,
            parameter_location_to_parameters: dict[ParameterLocation, list[Parameter]],
    ) -> dict[str, Any]:
        """Cast path, query and header values of ``request`` into their declared schemas.

        Body parameters are skipped (the handler processes the body separately).

        :raises NotImplementedError: if any parameter is located in a cookie.
        """
        raw_kwargs = {}
        for parameter_location, parameters in parameter_location_to_parameters.items():
            if parameter_location == ParameterLocation.body:
                continue
            if parameter_location == ParameterLocation.cookie:
                raise NotImplementedError("Cookie parameters are not supported")
            if parameter_location == ParameterLocation.path:
                source = request.path_match_key_to_value
            elif parameter_location == ParameterLocation.query:
                source = request.url_query_key_to_value
            elif parameter_location == ParameterLocation.header:
                source = request.metadata.header_name_to_value
            else:
                continue  # other locations are ignored, as before

            items = []
            for parameter in parameters:
                item = Item(parameter.name, parameter.schema)
                # NoValue is a sentinel: compare by identity, not equality.
                if parameter.default is not NoValue:
                    item.default = parameter.default
                items.append(item)
            raw_kwargs.update(DictCaster(items).cast(source))
        return raw_kwargs

    def register_handler(self, method: str, path: str, rpc_endpoint: RpcEndpoint) -> None:
        """Describe ``rpc_endpoint`` in the OpenAPI spec and register its request handler at ``path``.

        Parameters with a ``body_mime_type`` form the request body and must all be of the same
        ``Parameter`` subclass. ``RawBodyParameter``s receive ``request.payload``; other body
        parameters are parsed from ``request.parsed_body`` through a generated dataclass.

        :raises TypeError: if different body parameter types are combined, or raw body parameters
            declare different mime types.
        """
        operation_id = rpc_endpoint.operation_id or rpc_endpoint.callable.__name__
        all_parameters = rpc_endpoint.get_parameters()

        body_parameter_type_to_parameters: dict[Type[Parameter], list[Parameter]] = {}
        non_body_parameters = []
        for param in all_parameters:
            if param.body_mime_type is None:
                non_body_parameters.append(param)
            else:
                body_parameter_type_to_parameters.setdefault(type(param), []).append(param)

        if len(body_parameter_type_to_parameters) > 1:
            raise TypeError(f"Can not combine {list(body_parameter_type_to_parameters)} in one body")

        request_body = None
        request_body_dataclass_ = None
        body_parameter_type = None
        body_parameters: list[Parameter] = []
        if body_parameter_type_to_parameters:
            body_parameter_type = first(body_parameter_type_to_parameters)
            body_parameters = body_parameter_type_to_parameters[body_parameter_type]
            if issubclass(body_parameter_type, RawBodyParameter):
                mime_type = self._get_raw_body_mime_type(body_parameters)
                request_body = RequestBody(Content(mime_type, bytes))
            else:
                request_body_dataclass_ = self._make_body_dataclass(operation_id, body_parameters)
                request_body = RequestBody(Content(body_parameter_type.body_mime_type, request_body_dataclass_))
        is_raw_body = body_parameter_type is not None and issubclass(body_parameter_type, RawBodyParameter)

        self.wrapper.register_endpoint(Endpoint(
            path=path, method=method, operation_id=operation_id, securities=rpc_endpoint.securities,
            parameters=non_body_parameters,
            request_body=request_body,
            code_to_answer=self._build_code_to_answer(rpc_endpoint),
        ))

        # Group parameters by location once here instead of on every request.
        parameter_location_to_parameters: dict[ParameterLocation, list[Parameter]] = {}
        for param in all_parameters:
            parameter_location_to_parameters.setdefault(param.in_, []).append(param)

        async def handler(request: IncomingRequest) -> Answer:
            # TODO: implement securities
            result = None
            try:
                # Casting is inside the try so malformed path/query/header values are mapped to
                # answers exactly like malformed bodies are.
                raw_kwargs = self._cast_request_parameters(request, parameter_location_to_parameters)
                if is_raw_body:
                    # Store the payload under each raw parameter's own name (None by default),
                    # which is the key prepare_parameters looks it up by.
                    for parameter in body_parameters:
                        raw_kwargs[parameter.name] = request.payload
                elif request_body_dataclass_ is not None:
                    # TODO: collect extra data and put it to "warning" in answer
                    dc = dict_to_dataclass(request.parsed_body, request_body_dataclass_)
                    for parameter in body_parameters:
                        raw_kwargs[parameter.name] = getattr(dc, parameter.name)

                kwargs = await rpc_endpoint.prepare_parameters(raw_kwargs)
                call_result = await acall(rpc_endpoint.callable(**kwargs))
                result = rpc_endpoint.answer_type(call_result)
            except Exception as e:
                for exception_type, answer_type in rpc_endpoint.exception_type_to_answer_type.items():
                    if isinstance(e, exception_type):
                        result = answer_type(e)  # TODO: basic answer will break here
                        break

                if result is None:
                    if rpc_endpoint.unhandled_exception__answer_type is None:
                        raise
                    logger.exception(e)
                    result = rpc_endpoint.unhandled_exception__answer_type(e)  # TODO: basic answer will break here

            return result

        self.context.http_server.register_handler(path, handler)


class CallTemplate(ParameterAggregation):
    """Aggregation whose value is the result of calling ``callable_`` with resolved arguments.

    Each positional/keyword argument is either a ``Parameter`` (looked up in the context by name)
    or a nested ``ParameterAggregation`` (executed against the same context and awaited).
    """

    def __init__(
            self,
            callable_: Callable,
            *args: Union[Parameter, ParameterAggregation],
            **kwargs: Union[Parameter, ParameterAggregation],
    ):
        self.callable = callable_
        self.args = args
        self.kwargs = kwargs

    @staticmethod
    async def _resolve(parameter: Union[Parameter, ParameterAggregation], context: dict) -> Any:
        """Return the value for one argument.

        :raises TypeError: if ``parameter`` is neither a ``Parameter`` nor a ``ParameterAggregation``.
        """
        if isinstance(parameter, ParameterAggregation):
            # execute() is a coroutine function: await it rather than passing the coroutine on.
            return await parameter.execute(context)
        if isinstance(parameter, Parameter):
            return context[parameter.name]
        raise TypeError(f"Unexpected parameter type: {type(parameter)}")

    async def execute(self, context: dict) -> Any:
        """Resolve all arguments from ``context`` and return the (awaited) call result."""
        args = [await self._resolve(parameter, context) for parameter in self.args]
        kwargs = {name: await self._resolve(parameter, context) for name, parameter in self.kwargs.items()}
        return await acall(self.callable(*args, **kwargs))

    def get_parameters(self) -> set[Parameter]:
        """Return every request ``Parameter`` used by this template, including nested ones.

        :raises TypeError: if an argument is neither a ``Parameter`` nor a ``ParameterAggregation``.
        """
        result = set()
        for arg in itertools.chain(self.args, self.kwargs.values()):
            if isinstance(arg, ParameterAggregation):
                # set.union returns a new set; update() accumulates in place.
                result.update(arg.get_parameters())
            elif isinstance(arg, Parameter):
                result.add(arg)
            else:
                raise TypeError(f"Unexpected arg type: {type(arg)}")
        return result
