#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
from dataclasses import dataclass, field

import entity_read.sql.inter.agg_func
from entity_read import sql
from entity_read.entity import Entity
from sqlalchemy.util import to_list

from .expression import Expression
from .subquery import SubQuery


# Maps the ``Aggregation.key`` string to the lower-level aggregation class
# (from ``entity_read.sql.inter.agg_func``) that ``Aggregation.eval`` builds.
key_to_function_factory = dict(
    array=entity_read.sql.inter.agg_func.AggregationArray,
    unique=entity_read.sql.inter.agg_func.AggregationUnique,
    count=entity_read.sql.inter.agg_func.AggregationCount,
    sum=entity_read.sql.inter.agg_func.AggregationSum,
    max=entity_read.sql.inter.agg_func.AggregationMax,
    min=entity_read.sql.inter.agg_func.AggregationMin,
)


@dataclass(frozen=False)
class Aggregation(Expression):
    """Expression node for an aggregate function call (``agg.<key>(...)``).

    ``key`` selects the lower-level aggregation class via
    ``key_to_function_factory``; ``args`` are the argument expressions and
    ``filter`` holds optional filter expressions, forwarded as ``filters``
    to the lower-level class when the node is evaluated.
    """

    key: str
    args: list[Expression]
    filter: list[Expression] = field(default_factory=list)

    def __repr_in_dumps__(self) -> str | dict:
        """Return the representation of this node used when dumping.

        When there are no filters and every argument dumps to a string, the
        node collapses to the compact form ``agg.<key>(<arg>,...)``.
        Otherwise the dict produced by the base class is returned with its
        ``key`` entry quoted and its ``args`` entry replaced by the dumped
        argument representations.
        """
        self_dict_repr: dict = super().__repr_in_dumps__()
        # NOTE(review): assumes the base-class dict holds the raw argument
        # expressions under "args" (they are dumped individually here).
        arg_reprs = [arg.__repr_in_dumps__() for arg in self_dict_repr["args"]]
        # The compact string form has no place for filters; using it for a
        # filtered aggregation would silently drop them from the dump.
        if not self.filter and all(isinstance(arg_repr, str) for arg_repr in arg_reprs):
            return f'agg.{self.key}({",".join(arg_reprs)})'
        self_dict_repr['key'] = f"'{self_dict_repr['key']}'"
        self_dict_repr['args'] = arg_reprs
        # NOTE(review): the "filter" entry is left exactly as the base class
        # produced it — confirm it is dumpable in that form.
        return self_dict_repr

    def _get_lower_type(self) -> type[sql.atoms.Aggregation]:
        """Look up the lower-level aggregation class registered for ``self.key``.

        Raises:
            KeyError: if ``self.key`` is not in ``key_to_function_factory``.
        """
        if (function_factory := key_to_function_factory.get(self.key)) is None:
            raise KeyError(f"Not found function {self.key!r}")
        return function_factory

    def eval(self, entity_type: type[Entity], variables: dict[str, sql.atoms.Selectable]) -> sql.atoms.Aggregation:
        """Evaluate arguments and filters and build the lower-level aggregation.

        Each argument and filter expression is evaluated with the same
        ``entity_type`` and ``variables``; the results are passed to the class
        found by ``_get_lower_type`` as ``args`` and ``filters``.
        """
        lower_type = self._get_lower_type()
        # noinspection PyArgumentList
        return lower_type(
            args=[arg.eval(entity_type, variables) for arg in self.args],
            filters=[f.eval(entity_type, variables) for f in self.filter]
        )

    def shortcut(self) -> str:
        """Return ``agg.<key>(<arg shortcuts>)``; filters are not included."""
        return f"agg.{self.key}({','.join([arg.shortcut() for arg in self.args])})"

    def where(self, filter: Expression | list[Expression] | None) -> 'Aggregation':
        """Return a new aggregation with the same key/args and ``filter`` as filters.

        Existing filters on ``self`` are replaced, not extended. ``filter``
        may be a single expression or a list of expressions; ``None`` means
        "no filters". The result is a plain :class:`Aggregation` that owns
        fresh copies of the argument and filter lists, so mutating it does not
        affect ``self`` or the caller's list.
        """
        # to_list(None) would return None, leaving the node unevaluable;
        # fall back to an empty list instead.
        filters = list(to_list(filter, []))
        return Aggregation(key=self.key, args=list(self.args), filter=filters)


class AggregationArray(Aggregation):
    """Shorthand for ``Aggregation(key="array", args=[query])``."""

    def __init__(self, query: SubQuery):
        single_arg = [query]
        super().__init__(args=single_arg, key="array")


class AggregationUnique(Aggregation):
    """Shorthand for ``Aggregation(key="unique", args=[query])``."""

    def __init__(self, query: SubQuery):
        single_arg = [query]
        super().__init__(args=single_arg, key="unique")


class AggregationCount(Aggregation):
    """Shorthand for ``Aggregation(key="count", args=[query])``."""

    def __init__(self, query: SubQuery):
        single_arg = [query]
        super().__init__(args=single_arg, key="count")


class AggregationSum(Aggregation):
    """Shorthand for ``Aggregation(key="sum", args=[query])``."""

    def __init__(self, query: SubQuery):
        single_arg = [query]
        super().__init__(args=single_arg, key="sum")


class AggregationMax(Aggregation):
    """Shorthand for ``Aggregation(key="max", args=[query])``."""

    def __init__(self, query: SubQuery):
        single_arg = [query]
        super().__init__(args=single_arg, key="max")


class AggregationMin(Aggregation):
    """Shorthand for ``Aggregation(key="min", args=[query])``."""

    def __init__(self, query: SubQuery):
        single_arg = [query]
        super().__init__(args=single_arg, key="min")
