# Copyright (C) 2021
# ABM, Moscow

# UNPUBLISHED PROPRIETARY MATERIAL.
# ALL RIGHTS RESERVED.

# Authors: Vasiliev Ivan <i.vasiliev@abm-jsc.ru>
import asyncio
import logging
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from enum import IntEnum, unique
from typing import Any, Optional, Callable, Dict, Union, List, Set

from async_tools import acall, AsyncOnStop, AsyncOnStart

from kafka_client.topic_sender import TopicSender
from .kafka_dict_consumer import KafkaDictConsumer
from ..deserializer import RecordDeserializer
from ..kafka_client.producer import Producer
from ..serializer import RecordSerializer
from ..record_processor.kafka_record import KafkaRecord
from ..record_processor.record_processor import RecordProcessor

# Module logger named after the dotted module path (not the file path), so that
# levels/handlers configured on parent package loggers apply to it.
logger = logging.getLogger(__name__)


@unique
class KafkaEventType(IntEnum):
    new_key = 1
    key_update = 2
    key_delete = 3


@dataclass
class KafkaMsgEvent:
    """One consumed record paired with the kind of change it represents."""

    # key of the consumed record
    key: Any
    # record value as received by the dict (the deletion marker for key_delete)
    msg: Any
    # classification of the record relative to the current local state
    event_type: KafkaEventType


# Deletion marker: publishing this value for a key removes that key from every
# KafkaTopicDict reading the topic. Consumed values are compared to it with `is`.
RemoveKeyFromTopicMessage = None


class KafkaTopicDict(
    # TODO figure out whether it is possible to use UserDict
    # UserDict,
    AsyncOnStart,
    AsyncOnStop
):

    """
    Dict-like view of a Kafka topic.

    The local state is built only from records of ``config.topic`` that
    ``RecordProcessor`` hands to ``_process_record``. Writes (``set``, ``pop``,
    ``[]=``, ``del``) publish a record through ``TopicSender``. They become
    visible locally only once that record is consumed back. A record whose value
    is ``RemoveKeyFromTopicMessage`` deletes its key.

    ``_on_start`` waits for the initial read of the topic. Read methods assert
    that it has completed.
    """

    @dataclass(frozen=True)
    class Config:
        # name of the backing topic
        topic: str

    @dataclass
    class Context:
        serializer: RecordSerializer
        deserializer: RecordDeserializer
        producer: Producer
        consumer: KafkaDictConsumer

    def __init__(self,
                 config: Config,
                 context: Context
                 ) -> None:
        """
        :param config: topic settings; also handed to TopicSender and RecordProcessor.
        :param context: serializer, deserializer, producer and consumer to use.
        :raises ValueError: if ``context.consumer`` is not a KafkaDictConsumer.
        """
        self.config = config
        self.context = context

        if not isinstance(self.context.consumer, KafkaDictConsumer):
            raise ValueError("consumer should be an instance of KafkaDictConsumer")

        self._topic_sender: TopicSender = TopicSender(
            self.config, TopicSender.Context(
                serializer=self.context.serializer,
                producer=self.context.producer))

        self._processor = RecordProcessor(
            config=self.config,
            context=RecordProcessor.Context(
                self.context.consumer,
                self.context.deserializer),
            record_callback=self._process_record)
        # local materialized state of the topic
        self._dict: Dict[Any, Any] = {}

        # key -> event cleared by set() and set again when a record for that key is consumed
        # TODO consider to remove _key_to_msg_event
        self._key_to_msg_event: Dict[Any, asyncio.Event] = defaultdict(asyncio.Event)
        # per-partition bookkeeping used to detect the end of the initial read
        self._partition_to_on_init_max_offset: Dict[int, int] = defaultdict(int)
        self._partition_to_current_offset: Dict[int, int] = defaultdict(int)
        self._partition_to_got_new_msg_event: Dict[int, asyncio.Event] = defaultdict(asyncio.Event)
        self._inited_event: asyncio.Event = asyncio.Event()
        # in-flight subscriber callbacks; each task removes itself when it finishes
        self._process_record_tasks: Set[asyncio.Task] = set()
        # in-flight produce calls started by set(); referenced here so they cannot be garbage collected
        self._produce_tasks: Set[asyncio.Future] = set()

        # callbacks are invoked with key=... (and value=... except for key_delete)
        self._event_type_to_callback: Dict[KafkaEventType, Callable[..., Any]] = {}

        self.event_type_to_local_action = {
            KafkaEventType.new_key: self._setitem,
            KafkaEventType.key_delete: self._pop,
            KafkaEventType.key_update: self._setitem,
        }

        self.context.consumer.register_call_before_fetch(self.config.topic, self._fill_partitions_offsets)

    async def _on_start(self) -> None:
        """Run the initial read of the topic (see ``_init``)."""
        logger.info(f"{type(self).__name__} _on_start started")
        await self._init()
        logger.info(f"{type(self).__name__} _on_start finished")

    async def _on_stop(self) -> None:
        """Wait for subscriber callbacks that are still running."""
        logger.info(f"{type(self).__name__} for {self.config.topic} _on_stop")
        try:
            await asyncio.gather(*self._process_record_tasks)
        except Exception as e:
            logger.info(f"{type(self).__name__}  for {self.config.topic} error in _on_stop:  {repr(e)}")
        logger.info(f"{type(self).__name__} for {self.config.topic} _on_stop done")

    async def _fill_partitions_offsets(self) -> None:
        """
        Capture the per-partition max offsets that the initial read must reach.

        Registered with the consumer through ``register_call_before_fetch``.
        ``_wait_on_init_read`` waits until each partition's current offset
        reaches ``max_offset - 1``.
        NOTE(review): assumes this is called once, before records are processed.
        A later call would reset the progress tracked for each partition.
        """
        partition_to_on_init_max_offset = \
            await self.context.consumer.get_topic_partition_to_max_offset(self.config.topic)
        for partition, max_offset in partition_to_on_init_max_offset.items():
            self._partition_to_on_init_max_offset[partition.partition] = max_offset
            # -1 means "nothing read yet". Starting at 0 treated a partition whose only
            # record is at offset 0 (max_offset == 1) as fully read before that record arrived.
            self._partition_to_current_offset[partition.partition] = -1

    async def _init(self) -> None:
        """
        Check that the clients are connected, then wait for the initial read.

        :raises RuntimeError: if the consumer or the producer is not connected.
        """
        if not self.context.consumer.is_connected():
            raise RuntimeError(f"consumer: {self.context.consumer} should be inited "
                               f"before kafka dict: {self.config.topic}")
        if not self.context.producer.is_connected():
            raise RuntimeError(f"producer: {self.context.producer} should be inited")
        logger.info(f"waiting {self.config.topic} on init read")
        await self._wait_on_init_read()
        self._inited_event.set()
        logger.info(f"{self.config.topic} init done")

    def get(self, key: Any, default: Optional[Any] = None) -> Union[dict, Any]:
        """Return the locally known value for ``key``, or ``default``."""
        assert self._inited_event.is_set(), "called before init"
        return self._dict.get(key, default)

    async def aget(self, key: Any, default: Optional[Any] = None) -> Union[dict, Any]:
        """
        Like ``get``, but if ``set()`` was called for ``key``, first wait until a
        record for ``key`` has been consumed.
        """
        assert self._inited_event.is_set(), "called before init"
        if key in self._key_to_msg_event:
            await self._key_to_msg_event[key].wait()
        return self._dict.get(key, default)

    def copy(self) -> Dict[Any, Any]:
        """Return a deep copy of the local state."""
        assert self._inited_event.is_set(), "called before init"
        return deepcopy(self._dict)

    def set(self, key: Any, value: Any) -> None:
        """
        Publish ``value`` for ``key`` to the topic without waiting.

        The local dict is not modified here; it changes when the record is consumed back.
        Producing runs in a background future. A failure is logged, and in that case
        ``set_with_confirm``/``aget`` for ``key`` keep waiting.
        """
        produce_future = asyncio.ensure_future(self._topic_sender.produce(value, key))
        # the event loop keeps only weak references to tasks: hold one until it is done
        self._produce_tasks.add(produce_future)
        produce_future.add_done_callback(self._on_produce_done)
        self._key_to_msg_event[key].clear()

    def _on_produce_done(self, future: asyncio.Future) -> None:
        """Forget a finished produce future and log its error, if any."""
        self._produce_tasks.discard(future)
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            logger.warning(f"failed to produce to topic: {self.config.topic} error: {repr(error)}")

    def pop(self, key: Any) -> Any:
        """
        Publish a deletion marker for ``key`` and return its current local value.

        :raises KeyError: if ``key`` is not in the local dict.
        """
        assert self._inited_event.is_set(), "called before init"
        val = self._dict[key]
        self.set(key, RemoveKeyFromTopicMessage)
        return val

    async def set_with_confirm(self, key: Any, value: Any) -> None:
        """
        Publish ``value`` for ``key``, then wait until a record for ``key`` is
        consumed, so the new value can be read with ``[]``.
        """
        assert self._inited_event.is_set(), "called before init"
        self.set(key, value)
        await self._key_to_msg_event[key].wait()

    async def pop_with_confirm(self, key: Any) -> Any:
        """
        Like ``pop``, but wait until the deletion has been consumed.

        :raises KeyError: if ``key`` is not in the local dict.
        """
        assert self._inited_event.is_set(), "called before init"
        val = self._dict[key]
        await self.set_with_confirm(key, RemoveKeyFromTopicMessage)
        return val

    def items(self):
        """Live view of the local (key, value) pairs."""
        assert self._inited_event.is_set(), "called before init"
        return self._dict.items()

    def keys(self):
        """Live view of the local keys."""
        assert self._inited_event.is_set(), "called before init"
        return self._dict.keys()

    def values(self):
        """Live view of the local values."""
        assert self._inited_event.is_set(), "called before init"
        return self._dict.values()

    async def wait_all_msgs(self) -> None:
        """Wait until every key passed to ``set()`` so far has had a record consumed."""
        assert self._inited_event.is_set(), "called before init"
        # snapshot: set() may add keys to the defaultdict while we are awaiting
        for msg_event in list(self._key_to_msg_event.values()):
            await msg_event.wait()

    def subscribe_on_key_update(self, processor: Callable[..., Any]) -> None:
        """Register ``processor(key=..., value=...)`` for updates of existing keys (before init only)."""
        assert not self._inited_event.is_set(), "called after init"
        self._event_type_to_callback[KafkaEventType.key_update] = processor

    def subscribe_on_new_key(self, processor: Callable[..., Any]) -> None:
        """Register ``processor(key=..., value=...)`` for new keys (before init only)."""
        assert not self._inited_event.is_set(), "called after init"
        self._event_type_to_callback[KafkaEventType.new_key] = processor

    def subscribe_on_key_delete(self, processor: Callable[..., Any]) -> None:
        """Register ``processor(key=...)`` for deletions (before init only)."""
        assert not self._inited_event.is_set(), "called after init"
        self._event_type_to_callback[KafkaEventType.key_delete] = processor

    async def _trigger_registered_callback(self, msg_event: KafkaMsgEvent) -> None:
        """Run the subscriber for ``msg_event``'s type, if any; log its errors instead of raising."""
        if msg_event.event_type in self._event_type_to_callback:
            kwargs = self._parse_msg_event(msg_event)

            msg_processor = self._event_type_to_callback[msg_event.event_type]
            try:
                await acall(msg_processor(**kwargs))
            except Exception as e:
                logger.warning(f"got error in handling key: {msg_event.key} "
                               f"msg: {msg_event.msg} "
                               f"event: {msg_event.event_type} "
                               f"error: {repr(e)}")

    def _trigger_local_action(self, msg_event: KafkaMsgEvent) -> None:
        """
        Apply ``msg_event`` to the local dict, then wake waiters on its key.

        Waiters created by ``set()`` are released by any record for the key,
        not necessarily the one they published.
        """
        kwargs = self._parse_msg_event(msg_event)
        # update the dict before signalling so woken waiters observe the new state
        self.event_type_to_local_action[msg_event.event_type](**kwargs)
        if msg_event.key in self._key_to_msg_event:
            self._key_to_msg_event[msg_event.key].set()

    @staticmethod
    def _parse_msg_event(msg_event: KafkaMsgEvent) -> Dict[str, Any]:
        """Build call kwargs: always ``key``, plus ``value`` unless it is a deletion."""
        kwargs = {"key": msg_event.key}
        if msg_event.event_type != KafkaEventType.key_delete:
            kwargs["value"] = msg_event.msg
        return kwargs

    async def _wait_on_init_read(self) -> None:
        """
        Wait until each partition known from ``_fill_partitions_offsets`` has been
        consumed up to offset ``max_offset - 1``.

        Only the offsets captured at start-up are awaited; later records are not.
        NOTE(review): if ``_fill_partitions_offsets`` has not run yet, no partitions
        are known and this returns immediately. Confirm the call order with the consumer.
        """
        # snapshot: _process_record may add partitions to the defaultdict while we await
        for partition in list(self._partition_to_current_offset):
            got_new_msg_event = self._partition_to_got_new_msg_event[partition]
            while self._partition_to_current_offset[partition] < self._partition_to_on_init_max_offset[partition] - 1:
                logger.debug(f" waiting fo full read for topic : {self.config.topic},"
                             f" current offset: {self._partition_to_current_offset[partition]}"
                             f" max offset: {self._partition_to_on_init_max_offset[partition]}")
                await got_new_msg_event.wait()
                got_new_msg_event.clear()

    def _process_record(self, record: KafkaRecord):
        """
        RecordProcessor callback: classify one consumed record, apply it locally and,
        after init, schedule the matching subscriber callback.
        """
        self._partition_to_current_offset[record.partition] = record.offset
        self._partition_to_got_new_msg_event[record.partition].set()
        if record.value is RemoveKeyFromTopicMessage:
            event = KafkaEventType.key_delete
            logger.debug(f"{record.key} was removed from {self.config.topic} kafka_dict")
        elif record.key in self._dict:
            event = KafkaEventType.key_update
            logger.debug(f"{record.key} was updated in topic: {self.config.topic} kafka_dict")
        else:
            event = KafkaEventType.new_key
            logger.debug(f"{record.key} was added to {self.config.topic} kafka_dict")
        msg_event = KafkaMsgEvent(record.key, record.value, event)
        self._trigger_local_action(msg_event)
        # subscribers are only notified about records consumed after the initial read
        if self._inited_event.is_set():
            task = asyncio.create_task(self._trigger_registered_callback(msg_event))
            # keep a reference while running; drop it when done so the set does not grow forever
            self._process_record_tasks.add(task)
            task.add_done_callback(self._process_record_tasks.discard)

    def _pop(self, key: Any) -> Any:
        """Remove ``key`` locally; log a warning and return None if it is absent."""
        if key not in self._dict:
            logger.warning(f" deleting not existing key: {key}")
            return
        return self._dict.pop(key)

    def _setitem(self, key: Any, value: Any) -> None:
        """Store ``value`` for ``key`` locally."""
        self._dict[key] = value

    def __getitem__(self, key: Any) -> Union[dict, Any]:
        assert self._inited_event.is_set(), "called before init"
        return self._dict[key]

    def __setitem__(self, key: Any, value: Any) -> None:
        """Same as ``set``: publish without waiting."""
        self.set(key, value)

    def __delitem__(self, key: Any) -> None:
        """Same as ``pop`` without the return value; raises KeyError if ``key`` is absent."""
        self.pop(key)

    def __contains__(self, key: Any) -> bool:
        # O(1) membership test instead of the __iter__ fallback scan
        assert self._inited_event.is_set(), "called before init"
        return key in self._dict

    def __iter__(self):
        assert self._inited_event.is_set(), "called before init"
        return self._dict.__iter__()
