#  Copyright (C) 2022
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import dataclasses
import enum
import itertools
import typing

from ..viewable_entity import ViewableEntity


class SqlContext:
    """State shared while rendering an expression tree to SQL.

    Args:
        self_table_name: Table name/alias returned by ``get_self``; column
            references are qualified with it when rendered.
        key_to_param: Optional dict that collects bound parameter values.
            When omitted, ``bind_params`` is True and literal values are
            rendered inline. When given (even an empty dict), ``bind_params``
            is False and literals are stored here via ``register_parameter``.

    NOTE(review): ``bind_params`` reads inverted relative to its meaning
    (True means "inline literals, do not bind"); the name is kept for
    compatibility with existing callers.
    """

    def __init__(self, self_table_name: str, key_to_param: typing.Optional[dict] = None):
        self.self_table_name = self_table_name
        self.bind_params = key_to_param is None
        self.key_to_param = key_to_param

    def get_self(self):
        """Return the table name/alias used to qualify columns."""
        return self.self_table_name

    def register_parameter(self, value):
        """Store ``value`` under a fresh ``l<N>`` key and return ``":l<N>"``.

        The search starts at ``len(key_to_param)`` and skips keys that are
        already present, so an existing entry (e.g. a pre-populated ``l1``)
        is never silently overwritten.

        Raises:
            TypeError: if the context was created without ``key_to_param``
                (``len(None)``).
        """
        index = len(self.key_to_param)
        key = f'l{index}'
        # Starting from len() alone can collide with keys already in the dict.
        while key in self.key_to_param:
            index += 1
            key = f'l{index}'
        self.key_to_param[key] = value
        return ":" + key


class NodeType(enum.StrEnum):
    root = 'root'
    # list = 'list'
    # object = 'object'
    children = 'children'
    # children = 'children'
    column = 'column'
    alias = 'alias'
    literal = 'literal'
    function = 'function'
    order = 'order'
    # tuple = 'tuple'


class EFunctionType(enum.StrEnum):
    eq = "eq"
    ge = "ge"
    gt = "gt"
    le = "le"
    lt = "lt"
    included = "in"
    string_contains = "string_contains"
    now = "now"
    len = "len"
    cast_to_str = "cast_to_str"
    is_null = "is_null"
    not_null = "not_null"
    or_ = "or_"
    and_ = "and_"


class EColumnType(enum.StrEnum):
    boolean = "boolean"


@dataclasses.dataclass
class Node:
    """Base class for every expression-tree node.

    ``type`` is not an ``__init__`` argument; concrete subclasses provide it
    as a class attribute. The ``to_sql`` / ``requires`` defined here are
    placeholders that return ``None`` and are overridden by subclasses.
    """
    type: NodeType = dataclasses.field(init=False)

    # TODO: replace key_to_expression with "existent attributes", where present table columns and relations
    def to_sql(self, context, key_to_expression = None) -> str:
        """Render this node as an SQL fragment.

        ``context`` supplies the self-table name and the parameter registry.
        ``key_to_expression`` optionally maps column names to replacement SQL.
        """

    def requires(self) -> typing.List['EColumn']:
        """Return the column nodes this node reads."""


class Selectable:
    """Mixin for nodes that expose an output name through ``get_name``."""

    def get_name(self):
        """Return the output name; this base version returns ``None``."""
        return None


@dataclasses.dataclass
class EColumn(Node, Selectable):
    """Reference to a column by name.

    Renders as ``<self table>.<value>``. If ``key_to_expression`` contains
    ``value``, the mapped SQL expression is returned instead.
    """
    type = NodeType.column
    value: str

    def to_sql(self, context, key_to_expression = None) -> str:
        if key_to_expression is not None and self.value in key_to_expression:
            return key_to_expression[self.value]
        return f'{context.get_self()}.{self.value}'

    def requires(self) -> typing.List['EColumn']:
        """A column requires exactly itself."""
        return [self]

    def get_name(self) -> str:
        """The output name of a column is its column name."""
        return self.value


@dataclasses.dataclass
class EAlias(Node, Selectable):
    """Gives a wrapped node an explicit output name.

    SQL rendering and column requirements are delegated unchanged to
    ``value``. Only ``get_name`` differs: it returns ``alias``. No ``AS``
    clause is emitted by ``to_sql`` itself.
    """
    type = NodeType.alias
    value: Node
    alias: str

    def to_sql(self, context, key_to_expression = None) -> str:
        wrapped = self.value
        return wrapped.to_sql(context, key_to_expression)

    def requires(self) -> typing.List['EColumn']:
        wrapped = self.value
        return wrapped.requires()

    def get_name(self) -> str:
        return self.alias


@dataclasses.dataclass
class ELiteral(Node):
    """A constant value.

    If the context was built with a parameter dict (``bind_params`` is
    False), the value is registered as a bound parameter and its placeholder
    is returned. Otherwise the value is inlined:
    - a non-empty list becomes ``ARRAY[...]``;
    - an empty list becomes ``'{}'``;
    - anything else becomes ``str(value)``.
    """
    type = NodeType.literal
    value: typing.Any

    def to_sql(self, context, key_to_expression = None) -> str:
        if not context.bind_params:
            return context.register_parameter(self.value)
        # NOTE(review): inlined values are neither quoted nor escaped. A str
        # renders bare and None renders as "None". Confirm inline mode only
        # receives trusted, pre-formatted values.
        if not isinstance(self.value, list):
            return str(self.value)
        if not self.value:
            return "'{}'"
        return 'ARRAY' + str(self.value)

    def requires(self) -> typing.List['EColumn']:
        """Literals read no columns."""
        return []


@dataclasses.dataclass
class EFunction(Node):
    """Base class for operator/function nodes.

    ``value`` names the operation; each concrete subclass sets it as a class
    attribute and implements ``to_sql``. Operands live in ``args``
    (positional) and ``kwargs`` (named).
    """
    type = NodeType.function
    value: EFunctionType = dataclasses.field(init=False)
    args: typing.List[typing.Union[EColumn, 'EFunction', ELiteral]] = dataclasses.field(default_factory=list)
    kwargs: typing.Dict[str, typing.Union[EColumn, 'EFunction', ELiteral]] = dataclasses.field(default_factory=dict)

    def _convert_args_to_sql(self, context, key_to_expression) -> typing.Tuple[typing.List[str], typing.Dict[str, str]]:
        """Render every operand to SQL: all positional ones first, then named ones.

        Rendering order matters because literal operands may register bound
        parameters on ``context`` as they are rendered.
        """
        rendered_positional = []
        for operand in self.args:
            rendered_positional.append(operand.to_sql(context, key_to_expression))
        rendered_named = {}
        for name, operand in self.kwargs.items():
            rendered_named[name] = operand.to_sql(context, key_to_expression)
        return rendered_positional, rendered_named

    def requires(self) -> list[EColumn]:
        """Concatenate the requirements of every ``Node`` operand, in operand order."""
        operands = itertools.chain(self.args, self.kwargs.values())
        return [
            column
            for operand in operands
            if isinstance(operand, Node)
            for column in operand.requires()
        ]

    def to_sql(self, context, key_to_expression = None) -> str:
        """Render this function as SQL; implemented by concrete subclasses."""


@dataclasses.dataclass
class EOrder(Node):
    """One ORDER BY item: an expression plus its direction and NULL placement.

    - ``asc=False`` appends ``DESC``.
    - ``nulls_last=True`` appends ``NULLS LAST``.
    - ``nulls_last=False`` appends ``NULLS FIRST``.
    - ``nulls_last=None`` appends nothing, leaving the database default.
    """
    type = NodeType.order
    value: typing.Union[EColumn, EFunction]
    asc: bool = True
    nulls_last: typing.Optional[bool] = None

    def to_order_sql(self, context, key_to_expression = None) -> str:
        """Render this item as an ORDER BY fragment."""
        expression_str = self.value.to_sql(context, key_to_expression)
        if not self.asc:
            expression_str += " DESC"
        if self.nulls_last is not None:
            # nulls_last=True -> NULLS LAST, nulls_last=False -> NULLS FIRST.
            expression_str += " NULLS " + ("LAST" if self.nulls_last else "FIRST")
        return expression_str

    def requires(self) -> typing.List[EColumn]:
        """Columns required by the ordered expression."""
        return self.value.requires()

# @dataclasses.dataclass
# class EAnyFunction(Node):
#     type = NodeType.function
#     value: EFunctionType
#     args: typing.List[typing.Any] = dataclasses.field(default_factory=list)
#     kwargs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class EFunctionLE(EFunction):
    """Renders ``args[0] <= args[1]``.

    kwargs are rendered but ignored. Fewer than two positional args raise
    IndexError.
    """
    value = EFunctionType.le

    def to_sql(self, context, key_to_expression = None) -> str:
        positional, _ = self._convert_args_to_sql(context, key_to_expression)
        lhs, rhs = positional[0], positional[1]
        return f'{lhs} <= {rhs}'


@dataclasses.dataclass
class EFunctionEQ(EFunction):
    """Renders ``args[0] = args[1]``.

    kwargs are rendered but ignored. Fewer than two positional args raise
    IndexError.
    """
    value = EFunctionType.eq

    def to_sql(self, context, key_to_expression = None) -> str:
        positional, _ = self._convert_args_to_sql(context, key_to_expression)
        lhs, rhs = positional[0], positional[1]
        return f'{lhs} = {rhs}'


@dataclasses.dataclass
class EFunctionGE(EFunction):
    """Renders ``args[0] >= args[1]``.

    kwargs are rendered but ignored. Fewer than two positional args raise
    IndexError.
    """
    value = EFunctionType.ge

    def to_sql(self, context, key_to_expression = None) -> str:
        positional, _ = self._convert_args_to_sql(context, key_to_expression)
        lhs, rhs = positional[0], positional[1]
        return f'{lhs} >= {rhs}'


@dataclasses.dataclass
class EFunctionGT(EFunction):
    """Renders ``args[0] > args[1]``.

    kwargs are rendered but ignored. Fewer than two positional args raise
    IndexError.
    """
    value = EFunctionType.gt

    def to_sql(self, context, key_to_expression = None) -> str:
        positional, _ = self._convert_args_to_sql(context, key_to_expression)
        lhs, rhs = positional[0], positional[1]
        return f'{lhs} > {rhs}'


@dataclasses.dataclass
class EFunctionLT(EFunction):
    """Renders ``args[0] < args[1]``.

    kwargs are rendered but ignored. Fewer than two positional args raise
    IndexError.
    """
    value = EFunctionType.lt

    def to_sql(self, context, key_to_expression = None) -> str:
        positional, _ = self._convert_args_to_sql(context, key_to_expression)
        lhs, rhs = positional[0], positional[1]
        return f'{lhs} < {rhs}'


@dataclasses.dataclass
class EFunctionIncluded(EFunction):
    """Membership test: ``args[0]`` in ``args[1]``, rendered as ``args[0] = ANY (args[1])``.

    If the right-hand operand is an empty literal list, nothing can match.
    The node then renders as ``FALSE`` (in both inline and bound-parameter
    modes) and requires no columns. ``to_sql`` and ``requires`` share one
    emptiness check, so they always agree.
    """
    value = EFunctionType.included

    def _rhs_is_empty_list(self) -> bool:
        """Return True when ``args[1]`` is known to be an empty list."""
        rhs = self.args[1]
        if isinstance(rhs, ELiteral):
            # Dataclass nodes are always truthy, so inspect the literal's value.
            return isinstance(rhs.value, list) and not rhs.value
        # Non-literal operands keep the original falsiness check.
        return not rhs

    def to_sql(self, context, key_to_expression = None) -> str:
        if self._rhs_is_empty_list():
            # Skip rendering the operands so no parameters get bound for dead SQL.
            return 'FALSE'
        args_sql, _kwargs_sql = self._convert_args_to_sql(context, key_to_expression)
        if args_sql[1] == "'{}'":
            # Right-hand side still rendered as an inline empty array literal.
            return 'FALSE'
        return f'{args_sql[0]} = ANY ({args_sql[1]})'

    def requires(self) -> typing.List[EColumn]:
        if self._rhs_is_empty_list():
            return []
        return super().requires()

@dataclasses.dataclass
class EFunctionStringContains(EFunction):
    """Case-insensitive pattern match rendered as ``args[0] ilike args[1]``.

    NOTE(review): no ``%`` wildcards are added here. A substring match
    presumably relies on the caller putting them in ``args[1]`` — confirm.
    """
    value = EFunctionType.string_contains

    def to_sql(self, context, key_to_expression = None) -> str:
        positional = self._convert_args_to_sql(context, key_to_expression)[0]
        subject, pattern = positional[0], positional[1]
        return f"{subject} ilike {pattern}"


@dataclasses.dataclass
class EFunctionCastToStr(EFunction):
    """Casts ``args[0]`` to text using the PostgreSQL ``::TEXT`` syntax."""
    value = EFunctionType.cast_to_str

    def to_sql(self, context, key_to_expression = None) -> str:
        positional = self._convert_args_to_sql(context, key_to_expression)[0]
        operand = positional[0]
        return f"{operand}::TEXT"


@dataclasses.dataclass
class EFunctionIsNull(EFunction):
    """Renders ``args[0] IS NULL``."""
    value = EFunctionType.is_null

    def to_sql(self, context, key_to_expression = None) -> str:
        positional = self._convert_args_to_sql(context, key_to_expression)[0]
        operand = positional[0]
        return f"{operand} IS NULL"


@dataclasses.dataclass
class EFunctionNotNull(EFunction):
    """Renders ``args[0] IS NOT NULL``."""
    value = EFunctionType.not_null

    def to_sql(self, context, key_to_expression = None) -> str:
        positional = self._convert_args_to_sql(context, key_to_expression)[0]
        operand = positional[0]
        return f"{operand} IS NOT NULL"


@dataclasses.dataclass
class EFunctionOr(EFunction):
    """Disjunction of all positional operands, rendered as ``(a OR b ...)``.

    With no positional operands it renders ``FALSE``, the identity of OR,
    instead of the invalid SQL ``()``. kwargs are rendered but ignored.
    """
    value = EFunctionType.or_

    def to_sql(self, context, key_to_expression = None) -> str:
        args_sql, _kwargs_sql = self._convert_args_to_sql(context, key_to_expression)
        if not args_sql:
            # "()" is not valid SQL; an empty disjunction is false.
            return "FALSE"
        return f"({' OR '.join(args_sql)})"


@dataclasses.dataclass
class EFunctionAnd(EFunction):
    """Conjunction of all positional operands, rendered as ``(a AND b ...)``.

    With no positional operands it renders ``TRUE``, the identity of AND,
    instead of the invalid SQL ``()``. kwargs are rendered but ignored.
    """
    value = EFunctionType.and_

    def to_sql(self, context, key_to_expression = None) -> str:
        args_sql, _kwargs_sql = self._convert_args_to_sql(context, key_to_expression)
        if not args_sql:
            # "()" is not valid SQL; an empty conjunction is true.
            return "TRUE"
        return f"({' AND '.join(args_sql)})"

#
# @dataclasses.dataclass
# class EFunctionCast(EFunction):
#     value = EFunctionType.cast
#
#     def to_sql(self):
#         ...

@dataclasses.dataclass
class EFunctionLen(EFunction):
    """Renders ``jsonb_array_length(args[0])``; other operands are ignored."""
    value = EFunctionType.len

    def to_sql(self, context, key_to_expression = None) -> str:
        positional = self._convert_args_to_sql(context, key_to_expression)[0]
        array_sql = positional[0]
        return f'jsonb_array_length({array_sql})'
#
#
# @dataclasses.dataclass
# class EFunctionNow(EFunction):
#     value = EFunctionType.now
#
#     def to_sql(self):
#         ...


# Entity type an ERoot is parameterized over; restricted to ViewableEntity and its subclasses.
T = typing.TypeVar("T", bound=ViewableEntity)


@dataclasses.dataclass
class ERoot(Node, typing.Generic[T]):
    """Root node of a query tree over ``entity``.

    Stores the query pieces: ``attrs``, ``vars``, ``filter`` nodes,
    ``order`` items, and optional ``limit`` / ``offset``. This class does
    not override ``to_sql``.

    NOTE(review): ``type`` is ``NodeType.children`` rather than
    ``NodeType.root``; confirm this is intentional.
    """
    type = NodeType.children
    entity: typing.Type[T]
    alias: typing.Optional[str] = None
    attrs: typing.List[Node] = dataclasses.field(default_factory=list)
    vars: typing.List[Node] = dataclasses.field(default_factory=list)
    filter: typing.List[Node] = dataclasses.field(default_factory=list)
    order: typing.List[EOrder] = dataclasses.field(default_factory=list)
    limit: typing.Optional[int] = None
    offset: typing.Optional[int] = None

    def get_name(self) -> str:
        """Return the name of this root, in order of preference:

        1. ``alias``, if it is non-empty;
        2. ``entity``, when it was given as a string;
        3. otherwise ``"result"``.
        """
        if self.alias:
            return self.alias
        return self.entity if isinstance(self.entity, str) else "result"

    def requires(self) -> typing.List[EColumn]:
        """A root contributes no column requirements; always empty."""
        return []
