#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import abc
import contextvars
import logging
import time
import typing
from dataclasses import dataclass, field
from enum import StrEnum
from functools import cached_property
from types import MappingProxyType
from typing import Callable, Any, Iterable, TypeVar, Awaitable, TypeAlias, Literal, ClassVar

import aiohttp
from async_tools import acall
from authlib.jose import JsonWebKey, JsonWebToken, JWTClaims
from authlib.jose.errors import JoseError
from http_tools import IncomingRequest
from init_helpers import raise_if

from .extras import DataclassProtocol, url_to_snake_case
from .parameter import ParameterLocation, QueryParameter, PathParameter, HeaderParameter, SpecParameter

logger = logging.getLogger(__name__)

# Raw credential string as extracted from a request (header value or query/path parameter).
AuthToken: TypeAlias = str
# Identity object produced by a scheme's resolver: a plain dict or a dataclass instance.
AuthInfo = TypeVar("AuthInfo", bound=dict | DataclassProtocol)

# Context-local cache of resolver outcomes used by SecurityScheme.evaluate: maps a scheme instance to either
# its (auth_info, scopes) result or the exception its resolver raised. The variable has no default and
# evaluate() falls back to a throwaway dict, so caching only takes effect once some caller sets it
# (presumably once per incoming request — TODO confirm where it is set).
resolver_cache = contextvars.ContextVar('resolver_cache')


class Unauthorized(Exception):
    """Raised when a request carries no usable credentials or the credentials fail validation."""


class Forbidden(Exception):
    """Raised when credentials were accepted but do not grant all of the required scopes."""


class AuthInfoException(Exception):
    """Raised when a getter fails to derive a handler argument from the resolved auth info."""


class SecuritySchemeType(StrEnum):
    api_key = "apiKey"
    http = "http"
    oauth2 = "oauth2"
    open_id_connect = "openIdConnect"
    mutual_tls = "mutualTLS"


@dataclass(frozen=True)
class SecurityScheme(abc.ABC, typing.Generic[AuthInfo]):
    """Base class for OpenAPI security schemes.

    A scheme pulls a raw token out of an incoming request (``_extract_token``) and hands it to ``resolver``
    (sync or async), which returns ``(auth_info, allowed_scopes)`` or raises.
    """
    type_: SecuritySchemeType
    resolver: Callable[[AuthToken], Awaitable[tuple[AuthInfo, Iterable[str]]] | tuple[AuthInfo, Iterable[str]]]
    do_log: bool = field(default=True, kw_only=True)

    @abc.abstractmethod
    def to_spec(self) -> dict:
        """Return the OpenAPI Security Scheme Object; subclasses extend this base ``{'type': ...}`` dict."""
        return {'type': self.type_}

    @abc.abstractmethod
    async def _extract_token(self, incoming_request: IncomingRequest) -> AuthToken:
        """Extract the raw credential from ``incoming_request`` (subclasses raise ``Unauthorized`` if absent)."""

    async def evaluate(self, incoming_request: IncomingRequest, required_scopes: frozenset[str]) -> AuthInfo:
        """Authenticate ``incoming_request`` and check that every scope in ``required_scopes`` is granted.

        The resolver outcome (result or exception) is memoised per scheme in the dict held by the
        ``resolver_cache`` context variable. When that variable is set, several requirements sharing one
        scheme resolve the token only once.

        Returns the resolver's ``auth_info``.
        Raises whatever ``_extract_token`` or the resolver raised, or ``Forbidden`` listing missing scopes.
        """
        cache = resolver_cache.get({})
        if self not in cache:
            auth_token = await self._extract_token(incoming_request)
            try:
                cache[self] = await acall(self.resolver(auth_token))
            except Exception as e:
                # Failures are cached too, so a rejected token is not re-resolved for the same cache.
                cache[self] = e
        result = cache[self]
        if isinstance(result, Exception):
            raise result
        auth_info, allowed_scopes = result
        # Resolvers may return any iterable; set and frozenset already support difference, so skip the copy.
        if not isinstance(allowed_scopes, (set, frozenset)):
            allowed_scopes = set(allowed_scopes)
        raise_if(missing_scopes := required_scopes - allowed_scopes,
                 Forbidden(f'missing scopes: {", ".join(sorted(missing_scopes))}'))
        return auth_info

    @property
    @abc.abstractmethod
    def key(self) -> str:
        """Identifier of this scheme (used as its name in generated specs — presumably; confirm with callers)."""

    def has(self, *args: str, **kwargs: type[Literal] | Callable[[AuthInfo], Any]) -> 'Security':
        """Return a ``Security`` for this scheme requiring scopes ``args`` and providing arguments ``kwargs``."""
        return Security(self, scopes=args).use(**kwargs)


@dataclass(frozen=True)
class ApiKeySecurityScheme(SecurityScheme, typing.Generic[AuthInfo]):
    """API-key security scheme: the token is read from the parameter ``name`` at ``location``.

    Supported locations are query, path and header; others raise ``NotImplementedError`` at construction.
    """
    type_: SecuritySchemeType = field(init=False, default_factory=lambda: SecuritySchemeType.api_key)
    location: ParameterLocation
    name: str

    def __post_init__(self):
        # Build the parameter eagerly so an unsupported ``location`` fails at construction time.
        # A plain access (rather than ``assert``) keeps this check active when Python runs with -O.
        _ = self._parameter

    def to_spec(self) -> dict:
        return SecurityScheme.to_spec(self) | {'in': self.location, 'name': self.name}

    @cached_property
    def key(self) -> str:
        return f'{self.type_.name}_in_{self.location.name}_{self.name}'

    @cached_property
    def _parameter(self) -> SpecParameter:
        """Parameter descriptor used to read the API key from a request."""
        if self.location == ParameterLocation.query:
            return QueryParameter(name=self.name, schema=str)
        if self.location == ParameterLocation.path:
            return PathParameter(name=self.name, schema=str)
        if self.location == ParameterLocation.header:
            return HeaderParameter(name=self.name, schema=str)
        # TODO: support ParameterLocation.cookie once a CookieParameter is available.
        raise NotImplementedError(f"ApiKeySecurityScheme does not support location {self.location!r} ")

    async def _extract_token(self, incoming_request: IncomingRequest) -> AuthToken:
        """Return the API key from the request; a missing parameter (``KeyError``) becomes ``Unauthorized``."""
        try:
            return await self._parameter.get(incoming_request, {})
        except KeyError:
            raise Unauthorized from None


@dataclass(frozen=True)
class HttpSecurityScheme(SecurityScheme, abc.ABC, typing.Generic[AuthInfo]):
    """HTTP authentication scheme read from the ``Authorization`` header.

    ``header_scheme`` is the expected header prefix including its separator (e.g. ``'Bearer '``). The rest
    of the header value is handed to ``resolver`` as the token.
    """
    type_: SecuritySchemeType = field(init=False, default_factory=lambda: SecuritySchemeType.http)
    header_scheme: str

    def to_spec(self) -> dict:
        # OpenAPI's ``scheme`` is the bare RFC 7235 scheme name ("Bearer"), not the header prefix "Bearer ".
        return SecurityScheme.to_spec(self) | {'scheme': self.header_scheme.strip()}

    @cached_property
    def key(self) -> str:
        return f'{self.type_.name}_{self.header_scheme}'.strip()

    async def _extract_token(self, incoming_request: IncomingRequest) -> AuthToken:
        """Return the Authorization header value with the ``header_scheme`` prefix removed.

        Raises ``Unauthorized`` if the header is absent or does not start with the scheme prefix.
        """
        dirty_token: str = incoming_request.metadata.header_name_to_value.get("Authorization")
        raise_if(not isinstance(dirty_token, str), Unauthorized("Failed to get token from Authorization header"))
        prefix_length = len(self.header_scheme)
        # RFC 7235: the auth-scheme token is case-insensitive, so "bearer x" is as valid as "Bearer x".
        raise_if(dirty_token[:prefix_length].lower() != self.header_scheme.lower(),
                 Unauthorized("Wrong authorization scheme"))
        return dirty_token[prefix_length:]


@dataclass(frozen=True)
class HttpBearerSecurityScheme(HttpSecurityScheme, typing.Generic[AuthInfo]):
    """HTTP ``Bearer`` authentication: the token follows the ``'Bearer '`` prefix of the Authorization header."""
    header_scheme: str = field(default='Bearer ', init=False)


@dataclass(frozen=True)
class HttpBasicSecurityScheme(HttpSecurityScheme, typing.Generic[AuthInfo]):
    """HTTP ``Basic`` authentication: the credentials follow the ``'Basic '`` prefix of the Authorization header."""
    header_scheme: str = field(default='Basic ', init=False)


class OAuth2Flow(StrEnum):
    implicit = "implicit"
    authorizationCode = "authorizationCode"
    password = "password"
    clientCredentials = "clientCredentials"


@dataclass(frozen=True)
class OAuth2SecurityScheme(SecurityScheme, abc.ABC, typing.Generic[AuthInfo]):
    """Placeholder for the OpenAPI ``oauth2`` security scheme; flows are not modelled yet.

    ``key`` and ``_extract_token`` are still abstract, so only subclasses implementing them can be instantiated.
    """
    type_: SecuritySchemeType = field(init=False, default_factory=lambda: SecuritySchemeType.oauth2)
    # TODO: model the OpenAPI ``flows`` object (OAuth2Flow -> flow settings).

    def to_spec(self) -> dict:
        # Previously returned None (body was ``...``) despite ``-> dict``. Return at least the base object.
        # TODO: add the spec-required ``flows`` entry once flows are modelled.
        return SecurityScheme.to_spec(self)


@dataclass(frozen=True, kw_only=True)
class OpenIdConnectSecurityScheme(SecurityScheme, typing.Generic[AuthInfo]):
    """OpenID Connect security scheme validating RS256-signed JWT bearer tokens.

    ``url`` is the provider's discovery document URL and must contain ``/.well-known``. The part before it is
    the expected ``iss`` claim. ``resolver`` is not a constructor argument: ``__post_init__`` builds one that
    does the following for each token:
    - verifies the signature against the provider's JWKS (re-fetched when older than ``jwk_lifetime``);
    - checks the claims and the issuer;
    - returns ``(claims, scopes)``.
    Because the dataclass is frozen, fetched state (``public_keys_url``, ``jwks``, ``jwks_got_at``) is
    memoised with ``object.__setattr__``.
    """
    url: str
    type_: SecuritySchemeType = field(init=False, default_factory=lambda: SecuritySchemeType.open_id_connect)
    resolver: Callable[[AuthToken], Awaitable[tuple[AuthInfo, Iterable[str]]] | tuple[AuthInfo, Iterable[str]]] = (
        field(init=False))
    header_scheme: ClassVar[str] = 'Bearer'
    expected_issuer: str = field(init=False)
    jwk_lifetime: ClassVar[float] = 60  # seconds a fetched key set is reused before re-downloading

    def __post_init__(self):
        well_known_index = self.url.find('/.well-known')
        raise_if(well_known_index == -1, ValueError('OpenIdConnectSecurityScheme URL requires ".well-known" part'))
        object.__setattr__(self, 'expected_issuer', self.url[:well_known_index])

        async def resolver(jwt_value: str) -> tuple[AuthInfo, frozenset[str]]:
            jwks = await self.get_jwks(time.time() - self.jwk_lifetime)
            jwt = JsonWebToken(['RS256'])
            try:
                claims: JWTClaims = jwt.decode(jwt_value, jwks)
                self._check_claims(claims)
            except (JoseError, ValueError) as e:
                # Token decode/validation failures are client errors: report them as Unauthorized.
                raise Unauthorized(str(e)) from None

            return claims, self._get_scopes(claims)

        object.__setattr__(self, 'resolver', resolver)

    @staticmethod
    def _get_scopes(claims: JWTClaims) -> frozenset[str]:
        """Return the scopes from the token's space-delimited ``scope`` claim.

        Raises ``Unauthorized`` if the claim is missing, is not a string, or does not contain ``openid``.
        """
        if (scope := claims.get('scope')) is None:
            raise Unauthorized('BadJWT: Missing scope')
        if not isinstance(scope, str):
            raise Unauthorized(f'BadJWT: Scope is not string: {type(scope).__name__}')
        # split() rather than split(' '): repeated/leading/trailing spaces must not produce an empty scope.
        scopes = frozenset(scope.split())
        if 'openid' not in scopes:
            raise Unauthorized(f'BadJWT: Scope MUST contain openid, got: {scope!r}')

        return scopes

    def _check_claims(self, claims: JWTClaims) -> None:
        """Run authlib's ``JWTClaims.validate`` and require ``iss`` to equal ``expected_issuer``."""
        claims.validate()
        if (issuer := claims.get('iss')) != self.expected_issuer:
            raise Unauthorized(f'BadJWT: Invalid {issuer=}, expected: {self.expected_issuer!r}')
        # TODO: add "aud" check
        # if 'your-client-id' not in (audience := claims.get('aud', [])):
        #     raise Unauthorized(f'BadJWT: Invalid {audience=}')

    async def _extract_token(self, incoming_request: IncomingRequest) -> AuthToken:
        """Return the token from an ``Authorization: Bearer <token>`` header.

        Raises ``Unauthorized`` if the header is absent or uses another scheme.
        """
        dirty_token: str = incoming_request.metadata.header_name_to_value.get("Authorization")
        raise_if(not isinstance(dirty_token, str), Unauthorized("Failed to get token from Authorization header"))
        # RFC 7235 "<scheme> <credentials>": the scheme is case-insensitive. Splitting on the separator
        # rejects values such as "BearerXYZ" that a bare prefix check would accept.
        scheme, _, token = dirty_token.strip().partition(' ')
        raise_if(scheme.lower() != self.header_scheme.strip().lower(), Unauthorized("Wrong authorization scheme"))
        return token.strip()

    @property
    def key(self) -> str:
        return f'openid__{url_to_snake_case(self.expected_issuer)}'

    def to_spec(self) -> dict:
        return SecurityScheme.to_spec(self) | {'openIdConnectUrl': self.url}

    async def get_public_keys_url(self) -> str:
        """Return ``jwks_uri`` from the discovery document at ``url``, fetching it once and memoising it.

        Raises ``RuntimeError`` if the document cannot be fetched (including non-2xx responses) or lacks
        ``jwks_uri``.
        """
        if not hasattr(self, 'public_keys_url'):
            logger.debug(f'get_public_keys_url from {self.url}')
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.get(self.url, timeout=1) as resp:
                        resp.raise_for_status()
                        discovery_document = await resp.json()
                object.__setattr__(self, 'public_keys_url', discovery_document['jwks_uri'])
            except Exception as e:
                raise RuntimeError("Failed to get public keys url") from e

        return getattr(self, 'public_keys_url')

    async def _load_public_keys(self) -> list[dict]:
        """Download the provider's JWKS document and return its ``keys`` list (raises on non-2xx responses)."""
        public_keys_url = await self.get_public_keys_url()
        logger.debug(f'get_public_keys_url from {public_keys_url}')
        async with aiohttp.ClientSession() as session:
            async with session.get(public_keys_url, timeout=1) as resp:
                resp.raise_for_status()
                jwks = await resp.json()
        return jwks['keys']

    async def _refresh_public_keys(self) -> None:
        """Re-download the key set and record when it was fetched."""
        logger.debug('_refresh_public_keys')
        object.__setattr__(self, 'jwks', JsonWebKey.import_key_set({'keys': await self._load_public_keys()}))
        object.__setattr__(self, 'jwks_got_at', time.time())

    async def get_jwks(self, keys_newer_than: float):
        """Return the cached key set, refreshing it first if it was never fetched or fetched at/before ``keys_newer_than``."""
        got_at = getattr(self, 'jwks_got_at', None)
        if got_at is None or got_at <= keys_newer_than:
            await self._refresh_public_keys()
        return getattr(self, 'jwks')


@dataclass(frozen=True, slots=True)
class Security:
    """A security requirement for an endpoint.

    Bundles a scheme, the scopes it must grant, and handler arguments derived from the resolved auth info.
    Instances are immutable: ``has``/``use`` return new instances with extra scopes and/or argument getters.
    """
    scheme: SecurityScheme
    scopes: frozenset[str] = field(default_factory=frozenset)
    # Handler argument name -> getter applied to the resolved auth info (may return an awaitable).
    argument_name_to_getter: MappingProxyType[str, Callable[[AuthInfo], Awaitable | Any]] = field(
        default_factory=dict, hash=False)

    def __init__(self, scheme: SecurityScheme, scopes: Iterable[str] = tuple(),
                 argument_name_to_getter: dict[str, Callable[[AuthInfo], Awaitable | Any]] | None = None):
        object.__setattr__(self, 'scheme', scheme)
        object.__setattr__(self, 'scopes', frozenset(scopes))
        # Store a read-only view so the mapping cannot be modified through this instance.
        object.__setattr__(self, 'argument_name_to_getter', MappingProxyType(argument_name_to_getter or {}))

    async def evaluate(self, incoming_request: IncomingRequest) -> dict:
        """Authenticate the request via ``scheme`` and compute every handler argument from the auth info.

        Raises whatever ``scheme.evaluate`` raises (e.g. ``Unauthorized``/``Forbidden``), and
        ``AuthInfoException`` when a getter fails with ``KeyError``, ``TypeError`` or ``ValueError``.
        """
        auth_info = await self.scheme.evaluate(incoming_request, self.scopes)
        try:
            return {name: await acall(getter(auth_info)) for name, getter in self.argument_name_to_getter.items()}
        except (KeyError, TypeError, ValueError) as e:
            raise AuthInfoException(*e.args) from e

    def use(self, *args: str, **kwargs: type[Literal] | Callable[[AuthInfo], Any]) -> 'Security':
        """Alias of ``has``."""
        return self.has(*args, **kwargs)

    def has(self, *args: str, **kwargs: type[Literal] | Callable[[AuthInfo], Any]) -> 'Security':
        """Return a copy that also requires scopes ``args`` and also provides the arguments in ``kwargs``.

        Each keyword value is either a single-value ``Literal[...]`` (the argument is that constant) or a
        callable taking the auth info. Raises ``TypeError`` for anything else.
        """
        argument_name_to_getter = self.argument_name_to_getter.copy()
        for key, value in kwargs.items():
            # Checked before ``callable``: typing aliases such as Literal[...] may themselves be callable.
            if typing.get_origin(value) is Literal:
                args_args = typing.get_args(value)
                raise_if(len(args_args) != 1, TypeError(f'Only single literal values allowed, got: {args_args!r}'))
                # Bind the value now via a default argument. A plain closure over the loop variable would
                # make every Literal getter return the value of the last Literal processed.
                argument_name_to_getter[key] = lambda _, constant=args_args[0]: constant
            elif callable(value):
                argument_name_to_getter[key] = value
            else:
                raise TypeError(f'Got unexpected value: {value}, Allowed: Literal/Callable[[AuthInfo], Any]')
        return Security(scheme=self.scheme, scopes=self.scopes | set(args),
                        argument_name_to_getter=argument_name_to_getter)


def get_securities_arg_names(securities: list[Security | SecurityScheme]) -> set[str]:
    """Return the handler argument names provided by the alternative security requirements.

    A bare ``SecurityScheme`` provides no arguments. Every entry must provide exactly the same set of
    names, otherwise ``ValueError`` is raised. An empty ``securities`` list yields an empty set.
    """
    expected = None
    for item in securities:
        provided = set(item.argument_name_to_getter) if isinstance(item, Security) else set()
        if expected is not None:
            raise_if(provided != expected,
                     ValueError(f'All {securities=} must provide same arguments: {provided} != {expected}'))
        else:
            expected = provided
    return expected or set()
