#  Copyright (C) 2024
#  ABM JSC, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Albakov Ruslan <r.albakov@abm-jsc.ru>
import asyncio
import logging
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import AsyncIterator, Hashable

from redis_connector import RedisConnector
from abstract_semaphore import AbstractSemaphore

logger = logging.getLogger(__name__)

class RedisSemaphore(AbstractSemaphore):
    """Counting semaphore whose permits are fields of a Redis hash.

    Each caller of :meth:`restrict` for a given ``key`` writes its own uniquely
    named field into the Redis hash ``str(key)``. The number of fields in that
    hash is treated as the number of current holders. A caller is admitted once
    the field count after writing its own field is at most
    ``concurrency_limit``.
    """

    @dataclass
    class Context:
        # Connector used for every Redis command issued by the semaphore.
        redis: RedisConnector

    @dataclass
    class Config:
        # Passed as ``seconds=`` to ``hset_and_get_hlen``; presumably the
        # expiry that reclaims permits of crashed holders — TODO confirm
        # against RedisConnector.
        key_time_to_live_s: int = 120
        # Pause between acquisition attempts while the limit is exceeded.
        sleep_time_s: int = 3

    def __init__(self, config: Config, context: Context) -> None:
        self.config = config
        self.context = context

    @asynccontextmanager
    async def restrict(self, key: Hashable, concurrency_limit: int = 1) -> AsyncIterator[None]:
        """Hold one of ``concurrency_limit`` permits for ``key`` while the body runs.

        Polls every ``config.sleep_time_s`` seconds until a permit is free, then
        yields. On exit this caller's field is always removed from Redis. This
        also happens when acquisition itself is cancelled or fails, so an
        aborted waiter does not keep a slot occupied until expiry.

        :param key: resource identifier; ``str(key)`` names the Redis hash.
        :param concurrency_limit: maximum number of simultaneous holders.
        """
        redis_hash = str(key)
        # The timestamp keeps field names readable. The random suffix makes them
        # unique: two callers reading the same ``time.time()`` value (coarse
        # clocks, several processes) would otherwise share one field, both be
        # admitted while being counted as a single holder, and release each
        # other's permit.
        redis_key = f'{redis_hash}_{time.time()}_{uuid.uuid4().hex}'
        acquired = False
        try:
            # Acquire inside the ``try`` so the field written by an interrupted
            # attempt is cleaned up by the ``finally`` below.
            await self._check_limit_connection(redis_hash, redis_key, concurrency_limit)
            acquired = True
            logger.debug("%s connection acquired", redis_key)
            yield
        finally:
            if acquired:
                logger.debug("%s connection released", redis_key)
            # Safe even if the field was already withdrawn: the name is unique to
            # this call, so it can never remove another caller's permit.
            await self.context.redis.hdel(redis_hash, redis_key)

    async def _check_limit_connection(self, redis_hash: str, redis_key: str, concurrency_limit: int) -> None:
        """Wait until ``redis_key`` can remain in ``redis_hash`` within the limit.

        Each attempt calls ``hset_and_get_hlen``. Judging by its name, that call
        writes ``redis_key`` and returns the hash's field count. If the count
        exceeds ``concurrency_limit``, the field is withdrawn and the attempt is
        repeated after ``config.sleep_time_s`` seconds. On normal return the
        field is present in the hash.

        NOTE(review): waiters use a fixed back-off with no jitter. Contenders
        that retry in lock-step can keep seeing each other's fields and keep
        backing off together.
        NOTE(review): if ``seconds=`` sets an expiry, a holder whose body runs
        longer than ``key_time_to_live_s`` may silently lose its permit —
        confirm against RedisConnector.
        """
        while await self.context.redis.hset_and_get_hlen(
                redis_hash, redis_key, seconds=self.config.key_time_to_live_s
        ) > concurrency_limit:
            # Withdraw our field while waiting so it does not inflate the count
            # seen by other contenders.
            await self.context.redis.hdel(redis_hash, redis_key)
            await asyncio.sleep(self.config.sleep_time_s)
