#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Revva Konstantin <k.revva@abm-jsc.ru>


from dataclasses import dataclass
from typing import Callable, Awaitable, Any, Type
import json
import logging

from aiokafka import ConsumerRecord

from kafka_tools.kafka_client.consumer import Consumer
from kafka_tools.kafka_client.producer import Producer
from init_helpers.dict_to_dataclass import dict_to_dataclass


# Module-level logger, named after this module so its output can be filtered/configured per module.
logger = logging.getLogger(__name__)


@dataclass
class DetectorAnalyzeResult:
    """Analyze result decoded from a detector result-topic message.

    Instances are built from the JSON payload via ``dict_to_dataclass``.
    NOTE(review): field names/annotations presumably have to match the JSON keys
    and value types the detector emits, and ``dict_to_dataclass`` likely reads
    these annotations at runtime — do not rename or retype fields without
    checking the producer side.
    """
    rtsp_url: str
    snapshot_file_id: str
    # presumably a Unix timestamp in seconds — TODO confirm with the detector service
    snapshot_done_at: float
    reference_image_file_ids: list[str]
    # Raw per-request / per-result records; schema not visible from this module.
    detector_requests: list[dict[str, str | float | bool | int]]
    detector_results: list[dict[str, str | float | bool | int | list]]
    comparison_settings: dict[str, str | float | list]
    comparison_results: list[float]
    result: float
    # presumably a Unix timestamp in seconds — TODO confirm with the detector service
    analyzed_at: float


# Async callback that receives each deserialized analyze result; the awaited return value is ignored.
AnalyzeResultCallback = Callable[[DetectorAnalyzeResult], Awaitable[Any]]


class DetectorKafkaConnector:
    """Kafka bridge to the detector service.

    Produces analyze tasks to ``Config.detector_task_topic`` and consumes analyze
    results from ``Config.detector_result_topic``. Each consumed result is
    deserialized into ``detector_analyze_result_type`` and dispatched to every
    callback registered with :meth:`subscribe_on_analyze_result`.
    """

    @dataclass
    class Config:
        # Topic that analyze tasks are produced to.
        detector_task_topic: str
        # Topic that analyze results are consumed from.
        detector_result_topic: str

    @dataclass
    class Context:
        # Kafka clients supplied by the caller; this class does not create or close them.
        kafka_producer: Producer
        kafka_consumer: Consumer

    def __init__(self,
                 config: Config,
                 context: Context,
                 detector_analyze_result_type: Type[DetectorAnalyzeResult]) -> None:
        """Store dependencies and register the result-topic handler.

        :param config: topic names used by this connector.
        :param context: Kafka producer/consumer instances.
        :param detector_analyze_result_type: dataclass (``DetectorAnalyzeResult`` or a
            subclass) that incoming result payloads are converted into.
        """
        self.config = config
        self.context = context
        self.detector_analyze_result_type = detector_analyze_result_type
        self._analyze_result_callbacks: set[AnalyzeResultCallback] = set()
        # NOTE(review): results that arrive before any callback is subscribed are still
        # committed (the dispatch loop is simply empty) — confirm the consumer is only
        # started after subscribers have been registered.
        self.context.kafka_consumer.subscribe(self.config.detector_result_topic, self._process_task_result_kafka_record)

    async def create_analyze_task(self, stream_url: str, snapshot_id: str) -> None:
        """Publish an analyze task for ``snapshot_id`` taken from ``stream_url``.

        The task is sent as JSON-encoded bytes
        ``{"stream_url": ..., "snapshot_id": ...}`` to ``Config.detector_task_topic``.
        """
        payload = {
            "stream_url": stream_url,
            "snapshot_id": snapshot_id,
        }
        await self.context.kafka_producer.produce(self.config.detector_task_topic, json.dumps(payload).encode())
        # Lazy %-style args: the message is only formatted if DEBUG is enabled.
        logger.debug("sent snapshot_id %s for %s to %s", snapshot_id, stream_url, self.config.detector_task_topic)

    def subscribe_on_analyze_result(self, callback: AnalyzeResultCallback) -> None:
        """Register ``callback`` to be awaited for every received analyze result.

        Callbacks are kept in a set: registering the same callback twice has no
        additional effect, and the invocation order between callbacks is unspecified.
        """
        self._analyze_result_callbacks.add(callback)

    async def _process_task_result_kafka_record(self, record: ConsumerRecord) -> None:
        """Deserialize a result record, dispatch it to callbacks, then commit it.

        Callbacks are awaited one after another. If deserialization or any callback
        raises, the exception propagates and the record is not committed.
        """
        analyze_result = self._deserialize_analyze_result(record)
        # Iterate over a snapshot: every ``await`` yields to the event loop, and a
        # concurrent ``subscribe_on_analyze_result`` call would otherwise grow the set
        # mid-iteration and raise ``RuntimeError: Set changed size during iteration``.
        for callback in tuple(self._analyze_result_callbacks):
            await callback(analyze_result)
        await self.context.kafka_consumer.commit(record)

    def _deserialize_analyze_result(self, record: ConsumerRecord) -> DetectorAnalyzeResult:
        """Parse the record value as JSON and convert it to ``detector_analyze_result_type``.

        NOTE(review): assumes ``record.value`` is the raw JSON bytes/str (no value
        deserializer configured on the consumer) — confirm. A malformed payload raises
        here and, per ``_process_task_result_kafka_record``, is left uncommitted.
        """
        payload = json.loads(record.value)
        return dict_to_dataclass(payload, self.detector_analyze_result_type)
