#  Copyright (C) 2025
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import asyncio
from dataclasses import dataclass, field
from typing import Generic, Callable, TypeVarTuple, Unpack, Any, Awaitable, Literal, \
    get_origin, get_args
from typing import Mapping, Iterable

from async_tools import acall
from frozendict import frozendict
from http_tools import IncomingRequest
from init_helpers import Jsonable, raise_if

from openapi_tools.spec import SpecResource
from ..exceptions import AuthInfoException
from ..scheme import SecurityScheme, AuthInfo

# Variadic type parameters of SecurityRequirement: the positional arguments its argument getters
# receive, i.e. the auth infos produced by its security schemes (one per scheme, in scheme order).
Ts = TypeVarTuple('Ts')


@dataclass(frozen=True, slots=True)
class SecurityRequirement(SpecResource, Generic[Unpack[Ts]]):
    """Security requirement: a set of security schemes, each with the scopes it must grant.

    Besides contributing its fragment to the generated spec, a requirement can be evaluated against an
    incoming request: every scheme is evaluated, and the resulting auth infos are passed to the registered
    getters to build ``argument name -> value`` pairs.

    ``Ts`` describes the positional arguments the getters receive: one auth info per scheme, in
    ``scheme_to_scopes`` iteration order.
    """
    # Scheme -> scopes required from that scheme; normalized to frozendict / frozenset by __init__.
    scheme_to_scopes: Mapping[SecurityScheme, frozenset[str]] = field(default_factory=frozendict)
    # Argument name -> getter called with the auth infos of all schemes; stored as a frozendict and
    # excluded from the hash (it still takes part in equality).
    argument_name_to_getter: Mapping[str, Callable[[Unpack[Ts]], Any]] = field(default_factory=frozendict, hash=False)
    # Derived key: "<scheme key>[<sorted scopes joined by ','>]" per scheme, joined by ';'.
    string_key: str = field(init=False)
    # True when at least one of the schemes has do_log set.
    do_log: bool = field(init=False)

    def __init__(self, scheme_to_scopes: Mapping[SecurityScheme, Iterable[str]] | None = None,
                 argument_name_to_getter: Mapping[str, Callable[[AuthInfo], Awaitable | Any]] | None = None):
        """Build an immutable requirement.

        :param scheme_to_scopes: scheme -> iterable of scope names; each iterable is frozen into a frozenset.
        :param argument_name_to_getter: argument name -> getter; the mapping is copied into a frozendict.
        """
        # The dataclass is frozen, so attributes must be assigned through object.__setattr__.
        object.__setattr__(
            self, 'scheme_to_scopes',
            frozendict({scheme: frozenset(scopes) for scheme, scopes in (scheme_to_scopes or {}).items()}))
        object.__setattr__(self, 'argument_name_to_getter', frozendict(argument_name_to_getter or {}))
        # Scopes are sorted so the key does not depend on set iteration order.
        object.__setattr__(self, 'string_key', ';'.join(
            f'{scheme.get_key()}[{",".join(sorted(scopes))}]' for scheme, scopes in self.scheme_to_scopes.items()
        ))
        object.__setattr__(self, 'do_log', any(s.do_log for s in self.scheme_to_scopes))

    def get_spec_dependencies(self) -> frozenset['SpecResource']:
        """Return the security schemes this requirement refers to."""
        return frozenset(self.scheme_to_scopes.keys())

    def get_spec_dict(self, dependency_to_ref: Mapping['SpecResource', str]) -> frozendict[str, Jsonable]:
        """Return the spec fragment: scheme reference -> required scopes.

        Scopes are emitted as a sorted tuple instead of the stored frozenset for two reasons:
        - A frozenset is not serializable as a JSON array by the stdlib ``json`` module.
        - The iteration order of a set of strings changes between interpreter runs, which would make the
          generated spec non-reproducible.

        A tuple keeps the result hashable.

        :param dependency_to_ref: maps each dependency (see get_spec_dependencies) to the string used as its
            key here; presumably the security scheme's name in the spec components — confirm with the caller.
        """
        # noinspection PyTypeChecker
        return frozendict({
            dependency_to_ref[scheme]: tuple(sorted(scopes)) for scheme, scopes in self.scheme_to_scopes.items()
        })

    async def evaluate(self, incoming_request: IncomingRequest) -> dict:
        """Evaluate every scheme against the request and build the getter-derived arguments.

        All schemes are evaluated concurrently inside an ``asyncio.TaskGroup``. The auth infos are then passed
        positionally, in ``scheme_to_scopes`` order, to each getter. Each getter's result goes through
        ``acall`` (presumably awaiting it when it is awaitable — confirm in async_tools).

        :return: argument name -> value produced by the corresponding getter.
        :raises ExceptionGroup: raised by the TaskGroup when any scheme evaluation fails.
        :raises AuthInfoException: when a getter (or awaiting its result) raises KeyError, TypeError or ValueError.
        """
        async with asyncio.TaskGroup() as task_group:
            tasks = [
                task_group.create_task(scheme.evaluate(incoming_request, scopes))
                for scheme, scopes in self.scheme_to_scopes.items()
            ]
        auth_infos = [t.result() for t in tasks]
        try:
            return {name: await acall(getter(*auth_infos)) for name, getter in self.argument_name_to_getter.items()}
        except (KeyError, TypeError, ValueError) as e:
            raise AuthInfoException(*e.args) from e

    @staticmethod
    def _make_constant_getter(value: Any) -> Callable[..., Any]:
        """Return a getter that ignores its arguments and always returns ``value``.

        The factory binds ``value`` once per call. A lambda written directly inside the ``use`` loop would
        close over the loop variable instead, so every Literal argument would yield the last literal's value.
        The getter accepts any number of positional arguments because getters receive one auth info per scheme.
        """
        return lambda *_: value

    def use(self, **kwargs: type[Literal] | Callable[[Unpack[Ts]], Any]) -> 'SecurityRequirement':
        """Return a copy of this requirement with the given argument getters added or replaced.

        Each keyword maps an argument name to one of:

        * ``Literal[value]``: the argument always receives ``value``.
        * A callable: called with the auth infos of all schemes, one positional argument per scheme.

        :raises TypeError: for a Literal with more than one value, or for a value that is neither.
        """
        argument_name_to_getter = dict(self.argument_name_to_getter)
        for key, value in kwargs.items():
            # Checked before callable(): subscripted typing aliases such as Literal['x'] define __call__,
            # so callable() alone would accept them.
            if get_origin(value) is Literal:
                literal_values = get_args(value)
                raise_if(len(literal_values) != 1,
                         TypeError(f'Only single literal values allowed, got: {literal_values!r}'))
                argument_name_to_getter[key] = self._make_constant_getter(literal_values[0])
            elif callable(value):
                argument_name_to_getter[key] = value
            else:
                raise TypeError(f'Got unexpected value: {value}, Allowed: Literal/Callable[[AuthInfo], Any]')
        return SecurityRequirement(scheme_to_scopes=self.scheme_to_scopes,
                                   argument_name_to_getter=argument_name_to_getter)

    def has(self, **kwargs: type[Literal] | Callable[[Unpack[Ts]], Any]) -> 'SecurityRequirement':
        """Alias of :meth:`use`."""
        return self.use(**kwargs)

    def _repr(self) -> list[str]:
        """Return the ``name=value`` parts rendered by __repr__ (argument getters are not included)."""
        parts = []
        if self.scheme_to_scopes:
            # Sorted so the representation does not depend on set iteration order.
            scheme_to_sorted_scopes = {
                scheme: tuple(sorted(scopes)) for scheme, scopes in self.scheme_to_scopes.items()
            }
            parts.append(f'scheme_to_scopes={scheme_to_sorted_scopes}')
        return parts

    def __repr__(self):
        return f'{self.__class__.__name__}({", ".join(self._repr())})'

    __str__ = __repr__
