#  Copyright (C) 2025
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mikhail Mamrov <m.mamrov@abm-jsc.ru>
#
import logging
from typing import ClassVar, TypeAlias
import jwt
from jwt import DecodeError, ExpiredSignatureError

from init_helpers.dict_to_dataclass import dict_to_dataclass

from connector_library.auth_server_connector.auth_server_connector import AuthServerConnector, JWTPayload, TokenType

# Module-level logger, named after this module so handlers/levels can be configured per module.
logger = logging.getLogger(name=__name__)


class InvalidTokenError(Exception):
    """Raised when an access token is missing, cannot be decoded, or fails validation."""


class ExpiredTokenError(Exception):
    """Raised when a token's signature has expired, i.e. its lifetime has ended."""


# Readable alias for an encoded JSON Web Token carried as a plain string.
JWT: TypeAlias = str


class TokenValidator:
    """Validate RS256-signed access tokens against the auth server's public key.

    The public key is obtained lazily from ``auth_server_connector`` on the first
    validation and cached on the instance for all subsequent calls.
    """

    _SIGNATURE_ALGORITHM: ClassVar[str] = "RS256"

    def __init__(self, auth_server_connector: AuthServerConnector) -> None:
        """Store the connector used to obtain the public key (not fetched here)."""
        self._auth_server_connector = auth_server_connector
        # Populated on first use by validate_access_token and reused afterwards.
        # NOTE(review): the cached key is never refreshed, so a key rotation on the
        # auth server requires a new TokenValidator instance — confirm intended.
        self._public_key = None
        logger.info("%s inited", type(self).__name__)

    def validate_access_token(self, token: JWT | None, origin: str) -> JWTPayload:
        """Decode and validate an access token for the given origin.

        Args:
            token: Encoded JWT, or ``None``/empty when the request carried no token.
            origin: Origin the token must belong to; compared with
                ``payload.portal.origin``.

        Returns:
            The decoded payload converted to a ``JWTPayload``.

        Raises:
            InvalidTokenError: The token is missing or empty, fails signature or
                claim verification, is not an access token, or belongs to a
                different origin.
            ExpiredTokenError: The token's signature has expired.
        """
        # An empty string is as absent as None; reject it before fetching the key.
        if not token:
            raise InvalidTokenError("No access token")

        if self._public_key is None:
            self._public_key = self._auth_server_connector.get_public_key()

        try:
            payload = jwt.decode(token, self._public_key, algorithms=[self._SIGNATURE_ALGORITHM])
        except ExpiredSignatureError as err:
            # Must precede the generic handler below: ExpiredSignatureError is a
            # subclass of PyJWT's InvalidTokenError.
            raise ExpiredTokenError(f"{TokenType.ACCESS_TOKEN} lifetime ended") from err
        except jwt.InvalidTokenError as err:
            # PyJWT's base validation error (covers DecodeError, InvalidSignatureError,
            # InvalidAlgorithmError, ImmatureSignatureError, InvalidAudienceError, ...).
            # Translate all of them so callers only handle this module's exceptions.
            # Accessed via the module because the local InvalidTokenError shadows the name.
            raise InvalidTokenError("Invalid token") from err

        # NOTE(review): a payload lacking fields JWTPayload expects may make
        # dict_to_dataclass raise its own exception type, which is not translated
        # here — confirm whether it should map to InvalidTokenError.
        jwt_payload = dict_to_dataclass(payload, JWTPayload)

        if jwt_payload.type != TokenType.ACCESS_TOKEN:
            raise InvalidTokenError(f"Wrong token type, expected {TokenType.ACCESS_TOKEN}, got {jwt_payload.type}")

        if jwt_payload.portal.origin != origin:
            raise InvalidTokenError(f"Token doesn't belong to origin == {origin}")

        return jwt_payload
