#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>

import asyncio

from aiohttp import ClientSession, ClientResponse
from aiohttp.typedefs import StrOrURL
from http_tools import HttpStatusCode

from ..metrics_registry import MetricsRegistry


class MonitoredClientSession(ClientSession):
    """``ClientSession`` that reports per-request metrics to a ``MetricsRegistry``.

    Instrumentation is done by overriding aiohttp's private ``_request`` hook
    (through which the public request helpers are dispatched in aiohttp 3.x).
    For each request it updates:

    * ``http_client___requests`` -- counter, incremented once per request;
    * ``http_client__progress_requests`` -- tracked while the request runs;
    * ``http_client___requests_latency`` -- timed with the registry's default
      buckets, labelled with the resulting status code;
    * ``http_client__input_traffic`` -- counter incremented by the response's
      ``content_length`` (0 when it is unknown), only when a response was
      obtained.

    NOTE(review): the metric names mix ``__`` and ``___`` separators. They are
    kept unchanged because renaming them would change the series seen by
    consumers of these metrics.
    """

    def __init__(self, metrics_register: MetricsRegistry, *args, **kwargs) -> None:
        """Create the session.

        :param metrics_register: registry that receives the request metrics.
        :param args: positional arguments forwarded to ``ClientSession``.
        :param kwargs: keyword arguments forwarded to ``ClientSession``.
        """
        self._registry = metrics_register
        super().__init__(*args, **kwargs)

    async def _request(self, method: str, str_or_url: StrOrURL, **kwargs) -> ClientResponse:
        """Perform the request via ``ClientSession._request`` while recording metrics.

        :param method: HTTP method; its upper-cased form is the ``method`` label.
        :param str_or_url: request URL (``str`` or ``yarl.URL``); its string
            form is the ``url`` label.
        :param kwargs: forwarded unchanged to ``ClientSession._request``.
        :return: the response produced by the underlying session.
        :raises asyncio.TimeoutError: re-raised after being recorded with
            status ``GatewayTimeout``. Any other exception propagates unchanged
            and is recorded with status ``ServiceUnavailable``.
        """
        # Normalise to ``str`` so that a ``yarl.URL`` and the equivalent string
        # produce the same label value.
        # NOTE(review): the full URL (including any query string) is a label
        # value, which can create many distinct series -- confirm intended.
        labels = {'method': method.upper(), 'url': str(str_or_url)}
        self._registry.counter('http_client___requests', labels).inc()

        # Give the progress tracker its own snapshot. ``labels`` gains a
        # ``status_code`` key below, and the tracker must see the same label set
        # when the request starts and when it finishes (relevant if the registry
        # re-reads labels on exit -- copying is harmless otherwise).
        with self._registry.track_progress('http_client__progress_requests', dict(labels)):
            buckets = self._registry.config.default_buckets
            # Status recorded when no response is obtained and the error is not
            # a timeout; overwritten below on success or timeout.
            # NOTE(review): task cancellation also ends up recorded as this
            # status -- confirm that is the desired attribution.
            labels['status_code'] = HttpStatusCode.ServiceUnavailable.value
            # The live ``labels`` dict is passed here on purpose: ``status_code``
            # is filled in after the timer has been entered (presumably the
            # timer reads labels when it exits -- confirm in MetricsRegistry).
            with self._registry.track_time('http_client___requests_latency', buckets, labels):
                try:
                    response = await super()._request(method, str_or_url, **kwargs)
                except asyncio.TimeoutError:
                    labels['status_code'] = HttpStatusCode.GatewayTimeout.value
                    raise
                else:
                    labels['status_code'] = response.status

        # Only reached when a response was obtained; a missing length counts as 0.
        self._registry.counter('http_client__input_traffic', labels).inc(response.content_length or 0)

        return response
