#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import functools
from dataclasses import dataclass, field
from typing import TypeVar, Self, Generic

import sqlalchemy.orm
from entity_read.entity import Entity

from .expression import Expression
from .order import Order
from .query_token import QueryToken
from .remote_entity import RemoteEntity


# Type parameter of ``Relation``: the entity type a relation is subscribed
# with.  Bounded by ``RemoteEntity | str`` and declared covariant.
EntityType = TypeVar(
    'EntityType',
    bound=RemoteEntity | str,
    covariant=True,
)


@dataclass(frozen=True, kw_only=True, repr=False)
class Relation(QueryToken, Generic[EntityType]):
    """Query token that refers to a relation of an entity by its key.

    Subscribing the class (``Relation[SomeEntity]``) does not produce a typing
    alias: ``__class_getitem__`` is overridden to return a real subclass that
    stores ``SomeEntity`` in the class attribute ``_entity_type``.  Instances
    of such a subclass can build the related entity through ``related``.
    """

    # Key of the relation; used to look the relationship up on an entity type.
    key: str
    # Optional parent entity (set via ``__call__``); excluded from the hash.
    parent: RemoteEntity | None = field(hash=False, default=None)

    def __class_getitem__(cls, item: type[EntityType]):
        """Return a subclass of ``cls`` parametrized with ``item``.

        The subclass is named ``"<cls name>[<item name>]"`` and carries
        ``item`` in ``_entity_type``.  A new subclass is created on every
        subscription; results are not cached.
        """
        name = f"{cls.__name__}[{item.__name__}]"
        return type(name, (cls,), {"_entity_type": item})

    def __call__(self, parent: RemoteEntity | None) -> Self:
        """Return a new relation of the same type and key with ``parent`` set."""
        return type(self)(key=self.key, parent=parent)

    def shortcut(self) -> str:
        """Return the short textual form of this relation: ``rel.<key>``."""
        return f"rel.{self.key}"

    def shortcut_over(self) -> str:
        """Return the same string as ``shortcut``."""
        return self.shortcut()

    def get_parent(self) -> RemoteEntity | None:
        """Return the parent this relation was bound to, if any."""
        return self.parent

    def __str__(self) -> str:
        """Return ``Name[Entity](parent=...,key=...)`` for debugging output.

        The ``[Entity]`` part is omitted when the relation is not parametrized.
        """
        cls_name = type(self).__name__
        # ``getattr`` with a default: an unparametrized relation has no
        # ``_entity_type``, and printing it must not raise AttributeError.
        entity = getattr(self, "_entity", None)
        if entity:
            suffix = f"[{entity.__name__}]"
            # Subclasses built by ``__class_getitem__`` already have the
            # suffix in their name; do not print it twice.
            if not cls_name.endswith(suffix):
                cls_name += suffix
        return f'{cls_name}(parent={self.parent},key={self.key!r})'

    __repr__ = __str__

    @functools.cached_property
    def _type_arg(self) -> type[EntityType]:
        """Entity type this relation class was subscribed with.

        Raises:
            AttributeError: if ``_entity_type`` is not defined, i.e. the
                relation class was never parametrized.
        """
        return getattr(self, "_entity_type")

    @functools.cached_property
    def _entity(self) -> type[EntityType]:
        """Alias of ``_type_arg``."""
        return self._type_arg

    @functools.cached_property
    def related(self) -> EntityType:
        """Instance of the parametrized entity type, built with ``parent=self``.

        Computed once per relation instance and then cached.
        """
        return self._type_arg(parent=self)

    def get_related_entity(self, entity_type: type[Entity]) -> type[Entity]:
        """Return the entity class on the other side of this relation.

        Looks up the relationship registered under ``self.key`` on
        ``entity_type`` and returns its ``entity.entity``.

        Raises:
            KeyError: if ``entity_type`` has no relation with this key.
        """
        return self._get_relation(entity_type).entity.entity

    def get_local_column_name_to_remote(self, entity_type: type[Entity]) -> dict[str, str]:
        """Map local column names to remote column names for this relation.

        Built from the relationship's ``local_remote_pairs``.

        Raises:
            KeyError: if ``entity_type`` has no relation with this key.
        """
        relation = self._get_relation(entity_type)
        return {local_column.name: remote_column.name for local_column, remote_column in relation.local_remote_pairs}

    def get_remote_column_name_to_local(self, entity_type: type[Entity]) -> dict[str, str]:
        """Map remote column names to local column names for this relation.

        Inverse direction of ``get_local_column_name_to_remote``.

        Raises:
            KeyError: if ``entity_type`` has no relation with this key.
        """
        relation = self._get_relation(entity_type)
        return {remote_column.name: local_column.name for local_column, remote_column in relation.local_remote_pairs}

    def _get_relation(self, entity_type: type[Entity]) -> sqlalchemy.orm.Relationship:
        """Return the relationship stored under ``self.key`` on ``entity_type``.

        Raises:
            KeyError: if ``entity_type.get_key_to_relation()`` has no entry
                for ``self.key``; the message lists the entity's relation names.
        """
        if (relation := entity_type.get_key_to_relation().get(self.key)) is None:
            raise KeyError(f"Not found relation {self.key!r} in {entity_type.__name__}"
                           f"({entity_type.get_relation_names()})")
        return relation

    def subquery(
            self,
            attrs: list[Expression] | Expression | None = None,
            vars: dict[str, Expression] | None = None,
            filters: list[Expression] | None = None,
            searches: list[Expression] | None = None,
            orders: list[Order] | None = None,
            limit: int | None = None,
            offset: int | None = None,
    ) -> 'SubQuery':
        """Build a ``SubQuery`` over this relation.

        Falsy values for ``attrs``, ``filters``, ``searches`` and ``orders``
        are replaced with an empty list, and a falsy ``vars`` with an empty
        dict; ``limit`` and ``offset`` are passed through unchanged.
        """
        # Imported here rather than at module level (presumably to avoid a
        # circular import between the two modules).
        from .subquery import SubQuery
        return SubQuery(
            over=self, attrs=attrs or [], vars=vars or {}, filters=filters or [], searches=searches or [],
            orders=orders or [], limit=limit, offset=offset
        )

    def exists(
            self,
            attrs: list[Expression] | Expression | None = None,
            vars: dict[str, Expression] | None = None,
            filters: list[Expression] | None = None,
            searches: list[Expression] | None = None,
            orders: list[Order] | None = None,
            limit: int | None = None,
            offset: int | None = None,
    ) -> 'Exists':
        """Wrap ``subquery(...)`` built from the same arguments in ``Exists``."""
        # Imported here rather than at module level (presumably to avoid a
        # circular import between the two modules).
        from .function import Exists
        return Exists(self.subquery(
            attrs=attrs, vars=vars, filters=filters, searches=searches, orders=orders, limit=limit, offset=offset
        ))
