#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>
from contextlib import contextmanager, asynccontextmanager
from dataclasses import dataclass
from threading import Lock
from typing import Sequence, Any, Iterator, AsyncIterator, Callable, Awaitable

from async_tools import acall
from more_itertools import first
from prometheus_client import CollectorRegistry, Counter, Gauge, Summary, Histogram
from prometheus_client.utils import INF

from .tools.call import call
from .tools.timer import Timer


class BaseMetricsRegistry:
    """Get-or-create access to prometheus collectors kept in a private ``CollectorRegistry``.

    Lookups and creations are serialized with an internal lock.
    """

    @dataclass
    class Config:
        # Histogram buckets used when ``histogram()`` is called without explicit buckets.
        default_buckets: tuple[float, ...] = (.001, .003, .01, .03, .1, .3, 1., 3., 10., 30., 100., 300., 1000., INF)

    def __init__(self, config: Config) -> None:
        self.config = config
        self._lock = Lock()
        self._registry = CollectorRegistry(auto_describe=True)

    def counter(self, name: str, labels: dict[str, Any] | None = None, doc: str = '') -> Counter:
        """Return the counter ``name`` (bound to ``labels`` if any), creating it on first use."""
        return self._ensure_collector(Counter, name, labels, doc)

    def gauge(self, name: str, labels: dict[str, Any] | None = None, doc: str = '') -> Gauge:
        """Return the gauge ``name`` (bound to ``labels`` if any), creating it on first use."""
        return self._ensure_collector(Gauge, name, labels, doc)

    def summary(self, name: str, labels: dict[str, Any] | None = None, doc: str = '') -> Summary:
        """Return the summary ``name`` (bound to ``labels`` if any), creating it on first use."""
        return self._ensure_collector(Summary, name, labels, doc)

    def histogram(
            self, name: str, buckets: Sequence[float] | None = None, labels: dict[str, Any] | None = None, doc: str = ''
    ) -> Histogram:
        """Return the histogram ``name`` (bound to ``labels`` if any), creating it on first use.

        ``buckets`` falls back to ``config.default_buckets`` when empty/None; it only matters
        when the histogram is created, an existing histogram is returned unchanged.
        """
        return self._ensure_collector(Histogram, name, labels, doc, buckets=buckets or self.config.default_buckets)

    def _ensure_collector(
            self, required_type: type, name: str, labels: dict[str, Any] | None = None, doc: str = '', **kwargs
    ) -> Counter | Gauge | Summary | Histogram:
        """Look up the collector ``name`` in the registry or create it as ``required_type``.

        :param required_type: one of Counter, Gauge, Summary, Histogram.
        :param name: metric name used both for lookup and creation.
        :param labels: label name -> value; the keys define the label names on creation and
            must equal the existing collector's label names on later calls.
        :param doc: help text, used only on creation.
        :param kwargs: extra constructor arguments (e.g. ``buckets``), used only on creation.
        :returns: the collector, or its labelled child when ``labels`` is non-empty.
        :raises TypeError: ``required_type`` is not a supported collector class.
        :raises ValueError: ``name`` is taken by a collector of another type, or the label
            names do not match the existing collector's.
        """
        if not issubclass(required_type, (Counter, Gauge, Summary, Histogram)):
            raise TypeError(f'Inappropriate type: {required_type}')

        labels = labels or {}
        with self._lock:
            # CollectorRegistry has no public lookup by name, hence the private attribute.
            collector = self._registry._names_to_collectors.get(name)
            if collector is None:
                collector = required_type(name, doc, labels.keys(), registry=self._registry, **kwargs)
            elif not isinstance(collector, required_type):
                raise ValueError(f"Collector '{name}' with the '{type(collector)}' type already exists ")

            if mismatched_labels := set(collector._labelnames) ^ set(labels.keys()):
                raise ValueError(f'mismatched labels: {mismatched_labels}')

            if labels:
                # Bind values in the collector's own label order: that order was fixed when the
                # collector was first created and need not match this dict's key order.
                collector = collector.labels(*(labels[label_name] for label_name in collector._labelnames))

        return collector


class MetricsRegistry(BaseMetricsRegistry):
    """Metrics registry with context managers for instrumenting blocks of code."""

    @contextmanager
    def track_progress(self, name: str, labels: dict[str, Any] | None = None, doc: str = '') -> Iterator[None]:
        """Run the block inside ``Gauge.track_inprogress()`` of the gauge ``name``."""
        with self.gauge(name, labels, doc).track_inprogress():
            yield

    @contextmanager
    def track_time(
            self, name: str, buckets: Sequence[float] | None = None, labels: dict[str, Any] | None = None, doc: str = ''
    ) -> Iterator[None]:
        """Observe the block's ``Timer`` duration in the histogram ``name``.

        The observation happens in ``finally``, so failed blocks are timed too. ``labels`` is
        read only at that moment, i.e. after the block has finished.
        (Duration units are whatever ``Timer.duration`` reports; presumably seconds, TODO confirm.)
        """
        try:
            with Timer() as timer:
                yield
        finally:
            self.histogram(name, buckets, labels, doc).observe(timer.duration)

    @contextmanager
    def track(
            self,
            name: str,
            labels: dict[str, Any] | None = None,
            *,
            except_labels: dict[str, Any] | None = None,
            missing_label_value: Callable[[], Any] | Any | None = None
    ) -> Iterator[None]:
        """Track progress, duration and outcome of the block.

        Records the gauge ``<name>__progress``, the histogram ``<name>__spent_time`` and the
        counter ``<name>``.

        :param labels: labels applied to all three metrics.
        :param except_labels: labels merged in when the block raises.
        :param missing_label_value: on success, supplies the value for every ``except_labels``
            key absent from ``labels``. It is passed through ``tools.call`` after the block
            (presumably called if callable, else used as-is; see ``tools.call``).
        """
        missing_label_values = []
        with self._track(missing_label_values, name, labels, except_labels):
            yield
            missing_label_values.append(call(missing_label_value))

    @asynccontextmanager
    async def atrack(
            self,
            name: str,
            labels: dict[str, Any] | None = None,
            *,
            except_labels: dict[str, Any] | None = None,
            missing_label_value: Callable[[], Awaitable[Any] | Any] | Any | None = None
    ) -> AsyncIterator[None]:
        """Async counterpart of :meth:`track`; ``missing_label_value`` goes through ``acall``."""
        missing_label_values = []
        with self._track(missing_label_values, name, labels, except_labels):
            yield
            missing_label_values.append(await acall(missing_label_value))

    @contextmanager
    def _track(
            self,
            missing_label_values: list[Any],
            name: str,
            labels: dict[str, Any] | None = None,
            except_labels: dict[str, Any] | None = None
    ) -> Iterator[None]:
        """Shared implementation of :meth:`track` / :meth:`atrack`.

        :param missing_label_values: filled by the caller after a successful block; its first
            element fills the label keys that only ``except_labels`` provides.
        """
        # Private copy: it is extended below and must not leak into the caller's dict (a reused
        # dict would otherwise carry stale outcome labels into the next call).
        labels = dict(labels or {})
        except_labels = except_labels or {}
        required_keys = labels.keys() | except_labels.keys()
        with self.track_progress(f'{name}__progress', labels):
            # track_time reads `labels` only when the block exits, so the histogram gets the final
            # outcome-augmented label set, the same one the counter below gets.
            with self.track_time(f'{name}__spent_time', self.config.default_buckets, labels):
                try:
                    yield
                    labels |= {key: first(missing_label_values, None) for key in required_keys - labels.keys()}
                except BaseException:
                    # BaseException (not Exception) so that cancellation / KeyboardInterrupt also
                    # yield the full label set; otherwise the collector lookups would use an
                    # incomplete label set and could raise "mismatched labels", masking the error.
                    labels |= except_labels
                    raise
                finally:
                    self.counter(name, labels).inc()
