#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasiliev Ivan <i.vasiliev@technokert.ru>
from dataclasses import dataclass

from ....elements.aggregation_function import AggregationFunction, AggregationFunctionType
from ....elements.column import Column
from ....elements.filter_by import Condition
from ..base import AbstractFactory
from ..functions.function import Function


@dataclass
class AggFunction(Function):
    """Callable that wraps a column in an unconditional aggregation.

    The aggregation kind is resolved from this function's ``name`` via
    ``AggregationFunctionType``.
    NOTE(review): if ``AggregationFunctionType`` is an Enum, an unknown
    ``name`` presumably raises ``ValueError`` at call time — confirm.
    """

    # TODO: decide how count(*) (aggregation without a column) should be expressed.
    def __call__(self, column: Column) -> AggregationFunction:
        """Build the aggregation expression ``<name>(column)``.

        :param column: column the aggregation is applied to.
        :return: ``AggregationFunction`` combining the resolved type and ``column``.
        """
        func_type = AggregationFunctionType(self.name)
        return AggregationFunction(func_type, column)


@dataclass
class ConditionAggFunction(Function):
    """Callable that wraps a column in a conditional ("-if") aggregation.

    Like ``AggFunction``, but the produced ``AggregationFunction`` also
    carries a condition restricting which rows are aggregated.
    NOTE(review): if ``AggregationFunctionType`` is an Enum, an unknown
    ``name`` presumably raises ``ValueError`` at call time — confirm.
    """

    # TODO: decide how count(*) (aggregation without a column) should be expressed.
    def __call__(
        self,
        column: Column,
        condition: Condition.get_condition_type(),
    ) -> AggregationFunction:
        """Build the aggregation expression ``<name>(column)`` filtered by ``condition``.

        :param column: column the aggregation is applied to.
        :param condition: condition object, of the type reported by
            ``Condition.get_condition_type()``.
        :return: ``AggregationFunction`` combining the resolved type, ``column``
            and ``condition``.
        """
        func_type = AggregationFunctionType(self.name)
        return AggregationFunction(func_type, column, condition)


@dataclass
class AggFunctionFactory(AbstractFactory):
    """Resolve aggregation functions from attribute access.

    ``factory.sum`` returns ``AggFunction("sum")``. A trailing ``_if`` selects
    the conditional variant: ``factory.sum_if`` returns
    ``ConditionAggFunction("sum")``, whose call also takes a condition.
    """

    def __getattr__(self, key: str):
        """Return the aggregation callable named by ``key``.

        ``__getattr__`` is consulted only for attributes that normal lookup
        did not find, so any name that is not a real attribute is treated as
        an aggregation-function name.

        :param key: attribute name, e.g. ``"sum"`` or ``"sum_if"``.
        :return: ``ConditionAggFunction`` when ``key`` ends with ``"_if"``,
            otherwise ``AggFunction``.
        :raises AttributeError: for dunder names (see below).
        """
        # Refuse protocol probes such as copy.deepcopy's
        # getattr(obj, "__deepcopy__", None). Previously these received an
        # AggFunction("") and then crashed when the protocol invoked it.
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            )

        # Strip only a trailing "_if". Splitting on every "_" would silently
        # truncate names like "count_distinct" to "count".
        func_name, sep, suffix = key.rpartition("_")
        if sep and suffix == "if" and func_name:
            return ConditionAggFunction(func_name)
        return AggFunction(key)
