import itertools
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Dict, Type, Union, Optional
from async_tools import AsyncInitable
from dict_caster.extras import to_list

from .elements.aggregation_function import AggregationFunction
from .elements.filter_by import Filter, Condition
from .elements.order import Order
from .elements.column import Column
from .exceptions import WrongArgumentsException
from .elements.table import Table, OrderRepresentations, ColumnsRepresentations, ALL_COLUMNS
from .elements.output_format import OutputFormat
from .clickhouse_connector import ClickHouseConnector


# Module-level logger shared by everything in this file.
# NOTE(review): the logger is named after the file path (__file__) rather than the
# conventional dotted module name (__name__) — confirm this is intentional before
# changing it, since existing logging configuration may key on the current name.
logger = logging.getLogger(__file__)


class ClickHouseInterface(AsyncInitable):
    """Name-based facade over a ``ClickHouseConnector``.

    Table classes are registered under their ``__table_name__`` and later
    addressed by that name in :meth:`get`, :meth:`count` and :meth:`add`.
    String/dict representations of columns, filters, orders, etc. are
    converted by the entity class itself before being handed to the connector.
    """

    @dataclass
    class Context:
        """Dependencies required by the interface."""

        # Connector used for every statement issued by this interface.
        clickhouse_connector: ClickHouseConnector

    def __init__(self, context: Context):
        """Store the context and start with an empty entity registry."""
        AsyncInitable.__init__(self)
        self.context = context
        self._name_to_entity: Dict[str, Type[Table]] = {}

    async def async_init(self) -> None:
        """Ask the connector to create a table for every registered entity.

        Only entities registered before this call are processed.
        """
        connector = self.context.clickhouse_connector
        for table in self._name_to_entity.values():
            await connector.create_table(table)

    def register_entity(self, entity: Type[Table]) -> None:
        """Register ``entity`` under its ``__table_name__``.

        Registering another entity with the same table name replaces the
        previous one.
        """
        self._name_to_entity[entity.__table_name__] = entity
        logger.info(f'{entity} registered')

    def get_entity_by_name(self, entity_name: str) -> Type[Table]:
        """Return the entity class registered under ``entity_name``.

        Raises:
            WrongArgumentsException: if no entity is registered under that name.
        """
        entity = self._name_to_entity.get(entity_name)
        # Explicit None check: a registered class must never be treated as
        # missing just because it happens to evaluate as falsy.
        if entity is None:
            raise WrongArgumentsException(f'Unknown entity: {entity_name}')
        return entity

    async def get(self,
                  entity_name: str,
                  *,
                  columns: ColumnsRepresentations = ALL_COLUMNS,
                  filter_by: Optional[List[dict]] = None,
                  search_by: Optional[List[dict]] = None,
                  order_by: OrderRepresentations = None,
                  group_by: Optional[List[str]] = None,
                  offset: int = 0,
                  limit: int = 100,
                  output_format: str = OutputFormat.Dict.value,
                  aggregation_functions: Optional[List[dict]] = None
                  ) -> list:
        """Select data of a registered entity through the connector.

        Args:
            entity_name: Name the entity was registered under.
            columns: Representation of the columns to select.
            filter_by: Filter representations; each parsed filter is added to
                the condition with ``add_and``.
            search_by: Filter representations; each parsed filter is added to
                the condition with ``add_or``.
            order_by: Representation of the ordering.
            group_by: Names of the columns to group by.
            offset: Passed to the connector unchanged.
            limit: Passed to the connector unchanged.
            output_format: Output format name, parsed by the entity.
            aggregation_functions: Aggregation function representations; the
                parsed functions are selected after the columns.

        Returns:
            Whatever the connector's ``select`` returns.

        Raises:
            WrongArgumentsException: if ``entity_name`` is not registered.
        """
        entity = self.get_entity_by_name(entity_name)

        # Parsed values get their own names so the parameters are not
        # re-annotated with a different type.
        parsed_columns: List[Column] = entity.columns_from_str(columns)
        and_filters: List[Filter] = entity.filters_from_str(filter_by)
        or_filters: List[Filter] = entity.filters_from_str(search_by)
        orders: List[Order] = entity.orders_from_str(order_by)
        groups: List[Column] = entity.groups_from_str(group_by)
        parsed_format: OutputFormat = entity.output_format_from_str(output_format)
        parsed_aggregations: List[AggregationFunction] = \
            entity.aggregation_functions_from_str(aggregation_functions)

        # Columns first, aggregation functions after; either part may be empty/None.
        selectables = [*(parsed_columns or []), *(parsed_aggregations or [])]

        # NOTE(review): filters are attached with AND, searches with OR, in this
        # order; the resulting precedence is defined by Condition — confirm there.
        condition = Condition()
        # `or []` mirrors the guard used for selectables, so a parser that
        # yields None for "nothing requested" does not break the loops.
        for filter_ in and_filters or []:
            condition.add_and(filter_)
        for search in or_filters or []:
            condition.add_or(search)

        return await self.context.clickhouse_connector.select(
            entity, selectables=selectables, condition=condition, order_by=orders, group_by=groups,
            offset=offset, limit=limit, output_format=parsed_format)

    async def count(self,
                    entity_name: str,
                    *,
                    filter_by: Optional[List[dict]] = None,
                    search_by: Optional[List[dict]] = None,
                    ) -> int:
        """Count rows of a registered entity through the connector.

        Args:
            entity_name: Name the entity was registered under.
            filter_by: Filter representations, parsed by the entity.
            search_by: Filter representations, parsed by the entity.

        Returns:
            Whatever the connector's ``count`` returns.

        Raises:
            WrongArgumentsException: if ``entity_name`` is not registered.
        """
        entity = self.get_entity_by_name(entity_name)

        parsed_filters: List[Filter] = entity.filters_from_str(filter_by)
        parsed_searches: List[Filter] = entity.filters_from_str(search_by)

        return await self.context.clickhouse_connector.count(
            entity, filter_by=parsed_filters, search_by=parsed_searches)

    async def add(self, entity_name: str, values: Union[dict, List[dict]]) -> None:
        """Construct entity instances from ``values`` and insert them.

        ``values`` is normalised with ``to_list``. Instances are grouped by
        the tuple of their insert column names and every group is inserted
        with a separate connector call.

        Best effort: a value whose construction raises ``ValueError`` or
        ``TypeError`` is logged and skipped; the remaining values are still
        inserted.

        Raises:
            WrongArgumentsException: if ``entity_name`` is not registered.
        """
        values = to_list(values)

        entity = self.get_entity_by_name(entity_name)

        # Insert column names -> instances sharing exactly that column set.
        grouped_values = defaultdict(list)
        for val in values:
            try:
                prepared_val = await entity.constructor(**val)
                insert_column_names = tuple(prepared_val.to_insert_columns())
                grouped_values[insert_column_names].append(prepared_val)
            except (ValueError, TypeError) as e:
                # Deliberately swallowed so one bad value does not abort the batch.
                logger.error(f'failed to construct entity: {entity} from {val}, because: {repr(e)}')

        connector = self.context.clickhouse_connector
        for data_to_insert in grouped_values.values():
            await connector.insert(data_to_insert)
