import logging
import time
from dataclasses import dataclass
from typing import Type, Optional, List

from aiohttp import BasicAuth
from http_tools.http_server_connector import HttpServerConnector

from .elements.aggregation_function import AggregationFunction
from .elements.output_format import OutputFormat
from .statements.abstract_statement import AbstractStatement
from .statements.count import Count
from .statements.create_table import CreateTable
from .statements.insert import Insert
from .utils.charset import Charset
from .statements.select import Select
from .elements.result import Result
from .elements.filter_by import Filter
from .elements.order import Order
from .elements.column import Column
from .elements.table import Table


# Module-level logger named after the module's dotted import path, so it joins
# the package logger hierarchy (a file path would not).
logger = logging.getLogger(__name__)


class ClickHouseConnector:
    """Asynchronous ClickHouse client that talks to the server over HTTP.

    Every query is sent as the body of a POST request through an
    ``HttpServerConnector``, authenticated with HTTP Basic auth built from
    ``config.login`` / ``config.password``. The higher-level helpers build a
    statement object, execute its SQL and let the statement turn the raw
    response payload into a ``Result``.
    """

    @dataclass
    class BaseConfig:
        """ClickHouse-specific settings mixed into ``Config``."""
        # NOTE(review): ``database`` is not referenced anywhere in this class;
        # presumably it is consumed elsewhere (e.g. by statements) — verify.
        database: str
        login: str  # HTTP Basic auth user name
        password: str  # HTTP Basic auth password

    @dataclass
    class Config(HttpServerConnector.Config, BaseConfig):
        """Full connector configuration.

        Combines the HTTP settings of ``HttpServerConnector.Config`` with the
        ClickHouse credentials of ``BaseConfig``.
        """
        # When True, ``execute`` sends ``enable_http_compression=1`` and
        # ``Accept-Encoding: gzip`` so the server may gzip its responses.
        compression: bool = False

    # Alias of the context type expected by ``HttpServerConnector`` (and by
    # ``__init__`` below).
    Context = HttpServerConnector.Context

    def __init__(self, config: Config, context: Context):
        """Create a connector.

        :param config: connection settings; also passed as-is to the
            underlying ``HttpServerConnector``.
        :param context: context forwarded to ``HttpServerConnector``.
        """
        self.config = config
        self._http_connector = HttpServerConnector(config, context)

    async def execute(self, sql: str) -> str:
        """Send raw SQL to ClickHouse and return the response payload.

        The SQL is UTF-8 encoded and POSTed to ``config.location``. When
        ``config.compression`` is enabled, a gzip-compressed response is
        requested.

        :param sql: SQL text to execute.
        :return: the payload returned by ``HttpServerConnector.post``.
        """
        # Lazy %-formatting: the SQL of a large INSERT can be huge, and an
        # f-string would be rendered even when DEBUG logging is disabled.
        logger.debug('sql(%s):', sql)
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments the way time.time() can.
        request_start_time = time.perf_counter()
        payload = await self._http_connector.post(
            # NOTE(review): ``location`` presumably comes from
            # HttpServerConnector.Config — it is not declared in this file.
            self.config.location,
            payload=sql.encode(Charset.UTF8),
            headers={"Accept-Encoding": "gzip"} if self.config.compression else None,
            url_query={"enable_http_compression": 1} if self.config.compression else None,
            auth=BasicAuth(login=self.config.login, password=self.config.password),
        )
        elapsed = time.perf_counter() - request_start_time
        logger.debug('rows got in %s sec', round(elapsed, 3))
        return payload

    async def execute_statement(self, statement: AbstractStatement) -> Result:
        """Execute a statement object and let it parse the server response.

        :param statement: object providing ``generate_sql()`` (the SQL to
            send) and ``form_result(payload=...)`` (response parsing).
        :return: whatever ``statement.form_result`` builds from the payload.
        """
        payload = await self.execute(sql=statement.generate_sql())
        return statement.form_result(payload=payload)

    async def select(self,
                     table: Type[Table],
                     *,
                     columns: Optional[List[Column]] = None,
                     filter_by: Optional[List[Filter]] = None,
                     search_by: Optional[List[Filter]] = None,
                     group_by: Optional[List[Column]] = None,
                     order_by: Optional[List[Order]] = None,
                     offset: Optional[int] = None,
                     limit: Optional[int] = None,
                     output_format: OutputFormat = OutputFormat.Dict,
                     aggregation_functions: Optional[List[AggregationFunction]] = None
                     ) -> list:
        """Build and run a ``Select`` statement and return all fetched rows.

        Every argument is forwarded to ``Select``; their exact SQL semantics
        are defined there.

        :param table: table class to select from.
        :return: ``Result.fetchall()`` of the executed statement.
        """
        result = await self.execute_statement(
            # Arguments are passed positionally, so this order must match the
            # ``Select`` constructor signature exactly.
            statement=Select(
                table, columns, filter_by, search_by, group_by,
                order_by, offset, limit, output_format, aggregation_functions
            )
        )
        return result.fetchall()

    async def count(self,
                    table: Type[Table],
                    *,
                    filter_by: Optional[List[Filter]] = None,
                    search_by: Optional[List[Filter]] = None,
                    ) -> int:
        """Build and run a ``Count`` statement over ``table``.

        :param table: table class to count rows of.
        :param filter_by: forwarded to ``Count``.
        :param search_by: forwarded to ``Count``.
        :return: ``Result.fetchone()`` of the executed statement (annotated
            as ``int``).
        """
        result = await self.execute_statement(statement=Count(table, filter_by, search_by))
        return result.fetchone()

    async def insert(self, data_to_insert: List[Table]) -> None:
        """Insert the given table instances via an ``Insert`` statement.

        :param data_to_insert: instances to insert; the result is discarded.
        """
        await self.execute_statement(statement=Insert(data_to_insert))

    async def create_table(self, table: Type[Table]) -> None:
        """Run a ``CreateTable`` statement for ``table``; the result is discarded.

        :param table: table class describing the table to create.
        """
        await self.execute_statement(statement=CreateTable(table))
