from dataclasses import dataclass, field
from typing import Type, List, Optional, Union

from ..elements.aggregation_function import AggregationFunction
from ..elements.order import Order
from ..elements.filter_by import Filter
from ..elements.output_format import OutputFormat
from ..elements.result import SelectResult
from ..elements.table import Table
from ..elements.column import Column
from .abstract_statement import AbstractStatement
from ..utils.selectable import Selectable


@dataclass
class Select(AbstractStatement):
    """A SELECT statement against a single table.

    Collects the parts of a query (plain columns, aggregation functions,
    AND-ed filters, OR-ed search conditions, grouping, ordering and paging),
    renders them to SQL with :meth:`generate_sql`, and wraps the raw response
    with :meth:`form_result`.
    """

    # Table class queried; its engine gets a chance to adjust this request in __post_init__.
    table: Type[Table]

    # Plain columns to select; rendered before the aggregation functions.
    columns: Optional[List[Column]] = field(default_factory=list)
    # Conditions joined with AND into one parenthesized group.
    filter_by: Optional[List[Filter]] = field(default_factory=list)
    # Conditions joined with OR into one parenthesized group, AND-ed with filter_by.
    search_by: Optional[List[Filter]] = field(default_factory=list)
    # GROUP BY items, each rendered with str().
    group_by: Optional[List[Union[str, Column]]] = field(default_factory=list)
    # ORDER BY items, each rendered with str().
    order_by: Optional[List[Order]] = field(default_factory=list)
    # OFFSET value; None or 0 omits the clause (OFFSET 0 is a no-op).
    offset: Optional[int] = None
    # LIMIT value; None omits the clause, 0 is rendered as LIMIT 0.
    limit: Optional[int] = None
    # Passed through to SelectResult when the response is parsed.
    output_format: OutputFormat = OutputFormat.Dict
    # Aggregations to select; rendered after the plain columns.
    aggregation_functions: Optional[List[AggregationFunction]] = field(default_factory=list)

    # Fields that accept None from callers but are always lists after __post_init__.
    _LIST_FIELDS = ('columns', 'filter_by', 'search_by', 'group_by', 'order_by', 'aggregation_functions')

    def __post_init__(self):
        """Normalize list fields and run the table engine's request preprocessing."""
        # Normalize first so the engine hook never observes None for a list field...
        self._normalize_list_fields()
        self.table.get_table_engine().preprocess_select_request(self)
        # ...and again afterwards in case the hook reset any of them to None.
        self._normalize_list_fields()

    def _normalize_list_fields(self) -> None:
        """Replace every list-valued field that is None with a fresh empty list."""
        for name in self._LIST_FIELDS:
            if getattr(self, name) is None:
                setattr(self, name, [])

    @property
    def selectable(self) -> List[Selectable]:
        """Return a new list of the selected items: columns first, then aggregations.

        Both ``generate_sql`` and ``form_result`` use this property, so the
        SELECT list and the result parsing share the same order.
        """
        # noinspection PyTypeChecker
        return self.columns + self.aggregation_functions

    def generate_sql(self) -> str:
        """Render this statement as an SQL string.

        Returns:
            ``SELECT <items> FROM <table>``, followed by optional WHERE,
            GROUP BY, ORDER BY, LIMIT and OFFSET clauses, each on its own line.
        """
        table_name = self.table.get_name()
        # NOTE(review): if both columns and aggregation_functions are empty this
        # renders "SELECT  FROM ..." which is not valid SQL — confirm the engine
        # preprocessor always populates at least one of them.
        select_sql = ', '.join(item.to_selector() for item in self.selectable)

        where_filter_part = ' AND '.join(str(f) for f in self.filter_by)
        if where_filter_part:
            where_filter_part = f'({where_filter_part})'
        where_search_part = ' OR '.join(str(f) for f in self.search_by)
        if where_search_part:
            where_search_part = f'({where_search_part})'
        # AND the two groups together, skipping whichever one is empty.
        full_where = ' AND '.join(part for part in (where_filter_part, where_search_part) if part)
        where_sql = f'\nWHERE {full_where}' if full_where else ''

        group_by_part = ', '.join(str(g) for g in self.group_by)
        group_sql = f'\nGROUP BY {group_by_part}' if group_by_part else ''

        order_by_part = ', '.join(str(o) for o in self.order_by)
        order_sql = f'\nORDER BY {order_by_part}' if order_by_part else ''

        # `is not None` so that LIMIT 0 is honored instead of silently returning all rows.
        limit_sql = f'\nLIMIT {self.limit}' if self.limit is not None else ''
        # OFFSET 0 skips nothing, so a falsy offset just omits the clause.
        offset_sql = f'\nOFFSET {self.offset}' if self.offset else ''
        return f"SELECT {select_sql} FROM {table_name} {where_sql} {group_sql} {order_sql} {limit_sql} {offset_sql}"

    def form_result(self, payload: str) -> SelectResult:
        """Wrap the raw response payload in a SelectResult.

        Args:
            payload: Raw response text returned for the generated query.

        Returns:
            A SelectResult built from the payload, the selected items (in
            SELECT order) and this statement's output format.
        """
        return SelectResult(payload, self.selectable, self.output_format)
