#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import abc
import itertools
from dataclasses import dataclass, field
from typing import ClassVar, Any

from entity_read.sql.atoms.node import Node
from entity_read.sql.context import SqlContext
from .abstract import Selectable
from .column import Column


@dataclass(kw_only=True, frozen=True)
class Aggregation(Selectable, abc.ABC):
    """Abstract base for SQL aggregate expressions.

    Concrete subclasses set the ``code`` class variable (used by ``get_name``)
    and implement ``_to_sql``, which receives the already-rendered SQL of every
    argument plus an optional ``FILTER(WHERE ...)`` clause.
    """

    # Expressions passed to the aggregate function, rendered in order.
    args: tuple[Selectable[Any], ...]
    # Conditions rendered into a single ``FILTER(WHERE ...)`` clause, joined with AND.
    filters: tuple[Selectable[Any], ...] = tuple()
    # Aggregate function name used by ``get_name``; defined by each subclass.
    code: ClassVar[str]
    # NOTE(review): not read in this class; presumably consulted by the query
    # builder to decide whether an enclosing subquery must be kept — confirm.
    keep_subquery: ClassVar[bool] = False

    def requires(self) -> list[Column]:
        """Return the columns required by the arguments and filters.

        Only entries that are ``Node`` instances contribute; their ``requires()``
        results are concatenated (arguments first, then filters).
        """
        return list(itertools.chain.from_iterable(
            arg.requires()
            for arg in itertools.chain(self.args, self.filters)
            if isinstance(arg, Node)
        ))

    def get_name(self) -> str:
        """Return ``code(name1,name2,...)`` built from the arguments' names."""
        return f'{self.code}({",".join([arg.get_name() for arg in self.args])})'

    def to_sql(self, context: SqlContext, key_to_expression: dict | None = None) -> str:
        """Render the aggregate as SQL.

        :param context: rendering context forwarded to every argument and filter.
        :param key_to_expression: optional mapping forwarded unchanged to the
            ``to_sql`` of every argument and filter.
        :return: the SQL string produced by the subclass's ``_to_sql``.
        """
        filters_sql = ''
        if self.filters:
            conditions = [f.to_sql(context, key_to_expression) for f in self.filters]
            if len(conditions) > 1:
                # Parenthesize each condition so a filter rendering to e.g.
                # ``a OR b`` is not regrouped by AND's higher precedence.
                conditions = [f'({condition})' for condition in conditions]
            filters_sql = f' FILTER(WHERE {" AND ".join(conditions)})'
        args_sql = [arg.to_sql(context, key_to_expression) for arg in self.args]
        return self._to_sql(args_sql, filters_sql)

    @abc.abstractmethod
    def _to_sql(self, args_sql: list[str], filters_sql: str) -> str:
        """Combine rendered arguments and the filter clause into final SQL.

        :param args_sql: rendered SQL for each entry of ``args``, in order.
        :param filters_sql: ``''`` when there are no filters, otherwise a clause
            of the form ``' FILTER(WHERE ...)'`` (note the leading space).
        :return: the complete SQL for this aggregate.
        """
        ...
