#  Copyright (C) 2022
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Alexander Medvedev <a.medvedev@abm-jsc.ru>

import hashlib
import logging
import os
import time
from enum import unique, Enum
from typing import Optional
from urllib.request import parse_http_list

from aiohttp import BasicAuth, hdrs, ClientResponse
from aiohttp.web_exceptions import HTTPUnauthorized
from init_helpers import conditional_set
from http_tools.calculate_hash import HashAlgorithm, calculate_hash
from http_tools.charset import Charset

logger = logging.getLogger(__name__)


def to_bytes(value: str, charset: Optional[Charset] = Charset.UTF8) -> bytes:
    """Encode ``value`` to bytes.

    :param value: text to encode.
    :param charset: target charset; ``None`` falls back to ``Charset.UTF8``
        (``str.encode`` itself rejects ``None`` with ``TypeError``).
    :return: the encoded bytes.
    """
    return value.encode(Charset.UTF8 if charset is None else charset)


@unique
class DigestHeaderKey(str, Enum):
    """Parameter names read from a ``WWW-Authenticate: Digest`` challenge.

    Members mix in ``str``, so they compare and hash like their plain string
    values and can index the ``dict[str, str]`` built from the challenge.
    """
    realm = 'realm'  # required; part of A1 and echoed back in the header
    nonce = 'nonce'  # required; server-issued value hashed into the response
    qop = 'qop'  # optional; comma-separated protection options offered by the server
    algorithm = 'algorithm'  # optional; hash algorithm name, MD5 when absent
    opaque = 'opaque'  # optional; echoed back to the server when present


class DigestAuth(BasicAuth):
    """HTTP Digest access authentication helper built on ``aiohttp.BasicAuth``.

    ``encode()`` returns an empty string until a challenge is known. Once the
    caller hands over the 401 response via ``set_response_context()``,
    ``encode()`` answers that response's ``WWW-Authenticate: Digest`` challenge.
    """
    _DIGEST_PREFIX = 'Digest '

    def __new__(cls, username: str, password: str, *args) -> 'DigestAuth':
        # BasicAuth.__new__ names its first parameter ``login``; forwarding
        # positionally makes ``DigestAuth(username=..., password=...)`` work.
        # ``*args`` keeps accepting the extra field (encoding) that namedtuple
        # copy/pickle support passes back to __new__.
        return super().__new__(cls, username, password, *args)

    def __init__(self, username: str, password: str) -> None:
        self._username = username
        self._password = password

        # Nonce of the previous challenge and how many times it has been used;
        # the count is sent as the ``nc`` parameter.
        self._last_nonce = ''
        self._nonce_count = 0

        # Response carrying the challenge that build_digest_header() answers.
        self._response: Optional[ClientResponse] = None

        # While True, encode() returns an empty string (no challenge yet).
        self.is_first_request = True

    def _parse_digest_header(self) -> dict[str, str]:
        """Return the Digest challenge parameters from the stored response.

        All ``WWW-Authenticate`` values are scanned, because a server may offer
        several schemes (e.g. Basic and Digest). The scheme name is matched
        case-insensitively. Parameter names are lower-cased; surrounding
        double quotes are removed from values.

        :raises HTTPUnauthorized: if no Digest challenge is present.
        """
        header_values = self._response.headers.getall(hdrs.WWW_AUTHENTICATE, [])
        prefix = self._DIGEST_PREFIX.lower()
        challenge = next(
            (value for value in header_values if value[:len(prefix)].lower() == prefix),
            None,
        )
        if challenge is None:
            logger.error(f'Invalid header: {header_values}')
            raise HTTPUnauthorized()

        result = {}
        # parse_http_list only splits on commas outside quoted strings, so values
        # such as qop="auth, auth-int" or a realm containing ", " stay intact.
        for item in parse_http_list(challenge[len(prefix):]):
            # partition (not split) keeps '=' inside values, e.g. base64 padding.
            key, sep, val = item.partition('=')
            if not sep:
                continue  # skip malformed fragments instead of crashing
            val = val.strip()
            if len(val) >= 2 and val[0] == val[-1] == '"':
                val = val[1:-1]
            result[key.strip().lower()] = val
        return result

    def set_response_context(self, response: ClientResponse) -> None:
        """Remember ``response`` as the source of the challenge to answer."""
        self._response = response
        self.is_first_request = False

    def encode(self) -> str:
        """Encode credentials.

        :return: ``''`` before any challenge has been received, otherwise the
            Digest header value produced by build_digest_header().
        """
        return '' if self.is_first_request else self.build_digest_header()

    def build_digest_header(self) -> Optional[str]:
        """Build a ``Digest ...`` credentials string answering the stored challenge.

        :raises HTTPUnauthorized: if there is no Digest challenge, or its ``qop``
            does not offer ``auth``.
        :raises KeyError: if the challenge lacks ``realm`` or ``nonce``.
        """
        header_key_to_val = self._parse_digest_header()

        realm = header_key_to_val[DigestHeaderKey.realm]
        nonce = header_key_to_val[DigestHeaderKey.nonce]
        qop = header_key_to_val.get(DigestHeaderKey.qop)
        algorithm = header_key_to_val.get(DigestHeaderKey.algorithm, HashAlgorithm.md5).upper()
        opaque = header_key_to_val.get(DigestHeaderKey.opaque)

        # The digest covers the Request-URI as sent on the wire, i.e. still
        # percent-encoded; ``path_qs`` would be the decoded form.
        uri = self._response.url.raw_path_qs

        a_1 = f'{self._username}:{realm}:{self._password}'
        a_2 = f'{self._response.method}:{uri}'

        hash_algorithm = HashAlgorithm(algorithm)
        ha_1 = calculate_hash(to_bytes(a_1), hash_algorithm)
        ha_2 = calculate_hash(to_bytes(a_2), hash_algorithm)

        # nc counts consecutive uses of the same server nonce.
        self._nonce_count = self._nonce_count + 1 if nonce == self._last_nonce else 1
        self._last_nonce = nonce

        cnonce = b''.join([to_bytes(str(self._nonce_count)), to_bytes(nonce), to_bytes(time.ctime()), os.urandom(8)])
        hashed_cnonce = hashlib.sha1(cnonce).hexdigest()[:16]

        if hash_algorithm == HashAlgorithm.md5_sess:
            # Session variant: HA1 = H(H(user:realm:password):nonce:cnonce).
            ha_1 = calculate_hash(to_bytes(f'{ha_1}:{nonce}:{hashed_cnonce}'), hash_algorithm)

        nc_value = f'{self._nonce_count:08x}'
        digest_auth = 'auth'
        # qop may list several options, possibly with spaces: "auth-int, auth".
        qop_options = [option.strip() for option in qop.split(',')] if qop else []
        if not qop:
            # No qop offered: response = H(HA1:nonce:HA2).
            hashed_response = calculate_hash(to_bytes(f'{ha_1}:{nonce}:{ha_2}'), hash_algorithm)
        elif digest_auth in qop_options:
            nonce_bit = f'{nonce}:{nc_value}:{hashed_cnonce}:{digest_auth}:{ha_2}'
            hashed_response = calculate_hash(to_bytes(f'{ha_1}:{nonce_bit}'), hash_algorithm)
        else:
            logger.error(f'Digest header building failed. Incoming qop: {qop} is not correct')
            raise HTTPUnauthorized()

        key_to_value = {}
        # NOTE(review): conditional_set presumably skips None values (opaque is
        # optional) — confirm against init_helpers.
        conditional_set(key_to_value,
                        username=self._username,
                        realm=realm,
                        nonce=nonce,
                        uri=uri,
                        response=hashed_response,
                        opaque=opaque,
                        # Emit the token itself: f-string formatting of a mixed-in
                        # Enum member may render as "HashAlgorithm.xxx" on 3.12+.
                        algorithm=hash_algorithm.value)
        digest_header_value = ', '.join(f'{key}="{value}"' for key, value in key_to_value.items())
        if qop:
            digest_header_value += f', qop="{digest_auth}", nc={nc_value}, cnonce="{hashed_cnonce}"'

        return f'{self._DIGEST_PREFIX}{digest_header_value}'
