#  Copyright (C) 2022
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import dataclasses
from typing import Union, Type

import typing

# import sqlalchemy
import sqlalchemy.orm
from dict_caster.extras import first
from extended_logger import get_logger
from sqlalchemy.orm import InstrumentedAttribute

from .select_atoms import Node, EColumn, ELiteral, ERoot, EFunction, EAlias, EOrder, SqlContext
from .select_lower import LowerSelector, wrap_sql_as_list, generate_sql, wrap_sql_as_count
from ..components import EntityProxy, AttributeProxy, ProxyStep, descend, ascend
from ..viewable_entity import ViewableEntity

logger = get_logger(__name__)


@dataclasses.dataclass
class MiddleSelector(LowerSelector):
    """``LowerSelector`` that also knows its ORM entity and caches child selectors.

    Child selectors are created lazily by :meth:`get_sub_selector_for` and kept in
    ``step_to_sub_selector`` so that repeated references to the same relation
    (or the same alias) reuse a single sub-selector.
    """

    # Child selectors keyed by the step alias, or by the step name when no alias is set.
    step_to_sub_selector: dict[str, 'MiddleSelector'] = dataclasses.field(default_factory=dict)
    # Entity class (or instance) whose attributes are looked up by step name.
    entity: typing.Optional[Union[Type[ViewableEntity], ViewableEntity]] = None

    @classmethod
    def init_with_entity(cls, entity: Union[Type[ViewableEntity], ViewableEntity],
                         local_column_to_remote: typing.Optional[dict[str, str]] = None) -> 'MiddleSelector':
        """Create a selector for ``entity``, taking the table name from ``entity.get_table_name()``.

        :param entity: entity class (or instance) to select from.
        :param local_column_to_remote: column mapping forwarded to the selector;
            ``None`` (or any falsy value) is replaced by a fresh empty dict.
        :return: a new instance of ``cls`` (so subclasses construct their own type).
        """
        return cls(
            entity=entity,
            table_name=entity.get_table_name(),
            local_column_to_remote=local_column_to_remote or {},
        )

    def get_sub_selector_for(self, step: ProxyStep) -> 'MiddleSelector':
        """Return the child selector for ``step``, creating and caching it on first use.

        The attribute ``step.name`` is looked up on ``self.entity``. A relationship
        attribute is resolved with ``descend(attribute)``. A column attribute requires
        ``step.alias`` and is resolved with ``ascend(attribute, step.alias)``. The child
        is cached under ``step.alias`` when set, otherwise under ``step.name``.

        :param step: path step naming the attribute to follow.
        :return: the (possibly pre-existing) child selector.
        :raises TypeError: if the attribute is neither a relationship nor a column property.
        """
        attribute: InstrumentedAttribute = getattr(self.entity, step.name)
        attribute_name = step.name if step.alias is None else step.alias
        if attribute_name in self.step_to_sub_selector:
            logger.trace(f"sub selector {attribute_name} already exists")
            return self.step_to_sub_selector[attribute_name]

        logger.trace(f"sub selector {attribute_name} does not exist, create")
        if isinstance(attribute.property, sqlalchemy.orm.relationships.RelationshipProperty):
            next_entity_proxy: EntityProxy = descend(attribute)
        elif isinstance(attribute.property, sqlalchemy.orm.properties.ColumnProperty):
            # A column step needs an alias: it becomes the cache key and is passed to ``ascend``.
            assert step.alias is not None
            next_entity_proxy: EntityProxy = ascend(attribute, step.alias)
        else:
            # Report the property type: ``type(attribute)`` is always the instrumented wrapper.
            raise TypeError(f"attempt to get SubSelector for unexpected type: {type(attribute.property)}")
        self.step_to_sub_selector[attribute_name] = MiddleSelector.init_with_entity(
            entity=next_entity_proxy.get_entity(), local_column_to_remote=step.local_column_to_remote
        )
        return self.step_to_sub_selector[attribute_name]

# #
# # class NodeType(str, enum.Enum):
# #     root = 'root'
# #     # list = 'list'
# #     # object = 'object'
# #     children = 'children'
# #     # children = 'children'
# #     column = 'column'
# #     literal = 'literal'
# #     function = 'function'
# #     order = 'order'
# #     # tuple = 'tuple'
# #
# #
# # class EFunctionType(str, enum.Enum):
# #     eq = "eq"
# #     ge = "ge"
# #     gt = "gt"
# #     le = "le"
# #     lt = "lt"
# #     now = "now"
# #     cast = "cast"
# #
# # class EColumnType(str, enum.Enum):
# #     boolean = "boolean"
# #
# # @dataclasses.dataclass
# # class Node:
# #     type: NodeType = dataclasses.field(init=False)
# #
# #
# #
# # @dataclasses.dataclass
# # class MandatoryAlias:
# #     alias: str
# #
# #
# # @dataclasses.dataclass
# # class OptionalAlias(MandatoryAlias):
# #     alias: str = None
#
# # @dataclasses.dataclass
# # class EObject(Node):
# #     type = NodeType.object
# #     attrs: List[Node] = dataclasses.field(default_factory=list)
# #     filter_by: List[Node] = dataclasses.field(default_factory=list)
# #     alias: str = None
# #
# # @dataclasses.dataclass
# # class EList(EObject):
# #     type = NodeType.list
#
#
#
# @dataclasses.dataclass
# class EColumn(Node):
#     type = NodeType.column
#     value: str
#     alias: str = None
#
#     def get_name(self) -> str:
#         return self.alias or self.value
#
#     # def __post_init__(self):
#     #     if isinstance(self.value, ...):
#     #         self.value = ...
#
# @dataclasses.dataclass
# class ELiteral(Node):
#     type = NodeType.literal
#     value: typing.Any
#     alias: str = None
#
#     def get_name(self) -> str:
#         return self.alias
#
# @dataclasses.dataclass
# class EOrder(Node):
#     type = NodeType.order
#     value: EColumn
#     asc: bool = True
#
# @dataclasses.dataclass
# class EFunction(Node):
#     type = NodeType.function
#     value: EFunctionType = dataclasses.field(init=False)
#     args: typing.List[typing.Any] = dataclasses.field(default_factory=list)
#     kwargs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
#
#     def to_sql(self):
#         ...
# #
# # @dataclasses.dataclass
# # class EAnyFunction(Node):
# #     type = NodeType.function
# #     value: EFunctionType
# #     args: typing.List[typing.Any] = dataclasses.field(default_factory=list)
# #     kwargs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
#
#
# @dataclasses.dataclass
# class EFunctionLE(EFunction):
#     value = EFunctionType.le
#
#     def to_sql(self):
#         ...
#
# @dataclasses.dataclass
# class EFunctionEQ(EFunction):
#     value = EFunctionType.eq
#
#     def to_sql(self):
#         ...
#
# @dataclasses.dataclass
# class EFunctionGE(EFunction):
#     value = EFunctionType.ge
#
#     def to_sql(self):
#         ...
#
# @dataclasses.dataclass
# class EFunctionCast(EFunction):
#     value = EFunctionType.cast
#
#     def to_sql(self):
#         ...
#
# @dataclasses.dataclass
# class EFunctionLen(EFunction):
#     value = EFunctionType.cast
#
#     def to_sql(self):
#         ...
#
# @dataclasses.dataclass
# class EFunctionNow(EFunction):
#     value = EFunctionType.now
#
#     def to_sql(self):
#         ...
#
# @dataclasses.dataclass
# class ERoot(Node):
#     type = NodeType.children
#     entity: typing.Union[typing.Type[ViewableEntity], str]
#     alias: str = None
#     attrs: typing.List[Node] = dataclasses.field(default_factory=list)
#     filter: typing.List[Node] = dataclasses.field(default_factory=list)
#     order: typing.List[EOrder] = dataclasses.field(default_factory=list)
#
#     def get_name(self) -> str:
#         if self.alias:
#             return self.alias
#         if isinstance(self.entity, str):
#             return self.entity
#         return "result"
#
# @dataclasses.dataclass
# class EParent(Node):
#     type = NodeType.object
#     alias: str #  is mandatory
#     attrs: typing.List[Node] = dataclasses.field(default_factory=list)
#     filter: typing.List[Node] = dataclasses.field(default_factory=list)


#
# @dataclasses.dataclass
# class EChildren(Node):
#     type = NodeType.list
#     relation: str
#     alias: str = None
#     attrs: typing.List[Node] = dataclasses.field(default_factory=list)
#     filter: typing.List[Node] = dataclasses.field(default_factory=list)
#     order: typing.List[EOrder] = dataclasses.field(default_factory=list)
#
#     def get_name(self) -> str:
#         return self.alias or self.relation

#
#
# Q = ERoot(
#     Contract,
#     attrs=[
#         EColumn("id"),
#         EColumn("id", alias="id2"),
#         ELiteral(1, alias="one"),
#         EColumn("name"),
#         ERoot(
#             "stages",
#             attrs=[EColumn("id"), EColumn("start_at"), EColumn("end_at")],
#             order=[EOrder(EColumn("start_at"), asc=False)]
#         ),
#         ERoot(
#             "stages",
#             attrs=[EColumn("id"), EColumn("start_at"), EColumn("end_at")],
#             order=[EOrder(EColumn("start_at"), asc=False)],
#             alias="active_stages2",
#         ),
#         ERoot(
#             "stages",
#             filter=[
#                 EFunctionLE([EColumn("start_at"), EFunctionNow()]),
#                 EFunctionGE([EColumn("end_at"), EFunctionNow()]),
#             ],
#             alias="active_stages",
#         ),
#         # EParent(
#         #     "region_id", Region, ...
#         # )
#     ],
#     # filter=[
#     #     EColumn("is_active"),
#     #     ELiteral(True),
#     #     EFunctionEQ([ELiteral(1), ELiteral(1)]),
#     #     EFunctionCast([ELiteral(1), EColumnType.boolean]),
#     #     EFunctionGE([EFunctionLen(["active_stages"]), 1]),
#     #
#     # ]
# )

def process_attribute(selector: MiddleSelector, attribute: Node):
    """Register one requested output attribute on ``selector``.

    Dispatch by node type, checked in this order:

    * ``EAlias`` -- the wrapped ``attribute.value`` is stored under the alias name;
    * ``EColumn`` -- the column is stored under ``attribute.get_name()``;
    * ``ERoot`` -- ``attribute.entity`` names a relation of ``selector.entity``; the
      matching sub-selector is populated recursively from the node;
    * ``EntityProxy`` -- the sub-selector for the first path step is populated from
      an ``ERoot`` built for the proxy's entity;
    * ``AttributeProxy`` -- the sub-selector for the first path step is populated
      from an ``ERoot`` selecting just that attribute;
    * ``EFunction`` -- stored under its rendered SQL text.

    :param selector: selector being built; mutated in place.
    :param attribute: query node to register.
    :raises NotImplementedError: for a bare ``ELiteral``, which is not supported here.
    :raises TypeError: for any other node type.
    """
    if isinstance(attribute, EAlias):
        selector.result_name_to_ecolumn[attribute.get_name()] = attribute.value
    elif isinstance(attribute, EColumn):
        selector.result_name_to_ecolumn[attribute.get_name()] = attribute
    elif isinstance(attribute, ELiteral):
        # Raise explicitly: an ``assert`` here is stripped under ``python -O`` and the
        # literal would be silently dropped.
        # NOTE(review): presumably literals must be wrapped in ``EAlias`` to get a result name — confirm.
        raise NotImplementedError(f"process_attribute>- bare literal attribute is not supported: {attribute!r}")
    elif isinstance(attribute, ERoot):
        relation = getattr(selector.entity, attribute.entity)
        proxy: EntityProxy = descend(relation, attribute.get_name())
        sub_selector = selector.get_sub_selector_for(first(proxy.get_path()))
        get_sub_selector_from_nodes(attribute, sub_selector)
    elif isinstance(attribute, EntityProxy):
        sub_selector = selector.get_sub_selector_for(first(attribute.get_path()))
        get_sub_selector_from_nodes(ERoot(attribute.get_entity()), sub_selector)
    elif isinstance(attribute, AttributeProxy):
        sub_selector = selector.get_sub_selector_for(first(attribute.path))
        get_sub_selector_from_nodes(ERoot(attribute.attribute.class_, attrs=[attribute.attribute]), sub_selector)
    elif isinstance(attribute, EFunction):
        name = attribute.to_sql(SqlContext(selector.table_name))
        selector.result_name_to_ecolumn[name] = attribute
    else:
        logger.warning('process_attribute>- unexpected attribute: %s(%s)', attribute, type(attribute))
        raise TypeError(f"process_attribute>- unexpected attribute type: {type(attribute)!r}")


def process_variable(selector: MiddleSelector, variable: Node):
    """Build a hidden sub-selector for a named helper relation ("variable").

    ``variable`` must be an ``ERoot`` whose ``entity`` names a relation of
    ``selector.entity``. The matching sub-selector is flagged ``is_hidden`` and then
    populated from the variable's own nodes.
    """
    assert isinstance(variable, ERoot)
    entity_proxy: EntityProxy = descend(getattr(selector.entity, variable.entity), variable.get_name())
    hidden_selector = selector.get_sub_selector_for(first(entity_proxy.get_path()))
    hidden_selector.is_hidden = True
    get_sub_selector_from_nodes(variable, hidden_selector)


def process_filter(selector: LowerSelector, filter_: Node):
    """Append ``filter_`` (which must be an ``EFunction``) to ``selector.filter_by``."""
    assert isinstance(filter_, EFunction)
    conditions = selector.filter_by
    conditions.append(filter_)


def process_order(selector: LowerSelector, order: EOrder):
    """Append the ordering term ``order`` (which must be an ``EOrder``) to ``selector.order``."""
    assert isinstance(order, EOrder)
    ordering = selector.order
    ordering.append(order)


def get_sub_selector_from_nodes(root: ERoot, selector: LowerSelector = None) -> LowerSelector:
    """Populate ``selector`` from the nodes of ``root`` and return it.

    When ``selector`` is omitted, ``root.entity`` must be a ``ViewableEntity``
    subclass and a new ``MiddleSelector`` is created for it. The sections
    ``attrs``, ``vars``, ``filter`` and ``order`` are processed in that order.
    Afterwards ``root.limit`` and ``root.offset`` are copied onto the selector.
    """
    if selector is None:
        entity = root.entity
        assert isinstance(entity, type) and issubclass(entity, ViewableEntity)
        selector = MiddleSelector.init_with_entity(entity=entity)

    # Each section is read from ``root`` only when its turn comes, as in a sequential loop.
    section_handlers = (
        ("attrs", process_attribute),
        ("vars", process_variable),
        ("filter", process_filter),
        ("order", process_order),
    )
    for section_name, handle in section_handlers:
        for node in getattr(root, section_name):
            handle(selector, node)

    selector.limit = root.limit
    selector.offset = root.offset
    return selector


def e_root_to_sql(root: ERoot, count: bool = False) -> typing.Tuple[str, dict]:
    """Compile an ``ERoot`` query tree into SQL text plus its parameter dict.

    :param root: query tree; with no pre-built selector, its ``entity`` must be a
        ``ViewableEntity`` subclass.
    :param count: wrap the query with ``wrap_sql_as_count`` instead of ``wrap_sql_as_list``.
    :return: ``(sql, params)``, where ``params`` is the dict handed to ``generate_sql``
        (presumably filled with bind parameters there — confirm in ``select_lower``).
    """
    selector = get_sub_selector_from_nodes(root)
    params = {}
    inner_sql = generate_sql("", "", selector, params)
    wrap = wrap_sql_as_count if count else wrap_sql_as_list
    return wrap(inner_sql), params
#
# if __name__ == "__main__":
#     from entities.agent import Agent
#     # required to create links
#     from entities.playlist import Playlist
#     from entities.playlist_to_content_block import PlaylistToContentBlock
#     from entities.content_block import ContentBlock
#     from controller_module.selector_lower import align_sql
#
#     Q = ERoot(
#         Agent,
#         attrs=[
#             EAlias(EFunctionLen([EColumn("screens")]), alias="screens_amount"),
#             EColumn("id"), EAlias(EColumn("id"), alias="id2"), EAlias(ELiteral(1), alias="one"), EColumn("key"),
#             ERoot(
#                 "screens",
#                 attrs=[
#                     EColumn("id"), EAlias(EColumn("key"), alias="screen_key"), EAlias(ELiteral(2), alias="two"),
#                     ERoot("playlists", attrs=[EColumn("layout_x"), EColumn("layout_y")])
#                 ],
#                 # filter=[EFunctionLE([EColumn("id"), ELiteral(10)])]
#                 # order=[EOrder(EColumn("id"), asc=False)]
#             ),
#             # ERoot(
#             #     "stages",
#             #     attrs=[EColumn("id"), EColumn("start_at"), EColumn("end_at")],
#             #     order=[EOrder(EColumn("start_at"), asc=False)],
#             #     alias="active_stages2",
#             # ),
#         ],
#         vars=[
#             ERoot(
#                 "screens", attrs=[ERoot("playlists")],
#                 alias="active_screens",
#                 filter=[
#                     EFunctionGE([EFunctionLen([EColumn("playlists")]), ELiteral(1)])
#                 ],
#             )
#         ],
#         filter=[
#             EFunctionGE([EFunctionLen([EColumn("active_screens")]), ELiteral(1)])
#         ],
#         order=[EOrder(EFunctionLen([EColumn("screens")]), asc=True)]
#     )
#     print()
#     print()
#     print()
#     print()
#     print()
#     sub = get_sub_selector_from_nodes(Q)
#     print()
#     print()
#     print()
#     print()
#     print()
#     print(sub)
#
#     res = something4("", "", sub)
#     print(align_sql(res))
#
#
#     Q2 = ERoot(
#         Organisation,
#         attrs=[EColumn("id"), EColumn("title")],
#         vars=[ERoot("members", alias="members", filter=[EFunctionEQ([EColumn("user_id"), ELiteral(42)])])],
#         filter=[
#             EFunctionGT([EFunctionLen([EColumn("members")]), ELiteral(0)]),
#             EFunctionGT([EColumn("created_at"), ELiteral(0)])
#         ],
#         order=[EOrder(EColumn("created_at"))]
#     )
#     print()
#     print()
#     print()
#     print()
#     print(align_sql(something4("", "", get_sub_selector_from_nodes(Q2))))
