#  Copyright (C) 2022
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import dataclasses
from typing import Optional, NamedTuple

import sqlalchemy.orm
from dict_caster.extras import first
from sqlalchemy import ForeignKey
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql.elements import NamedColumn
from sqlalchemy_tools.entity_helpers.entity_name_to_class import add_to_entity_name_to_class, get_entity_by_name
from sqlalchemy_tools.entity_helpers.sqlalchemy_base import sqlalchemy_mapper_registry
from sqlalchemy_tools.utils.text import camel_case_to_underscore


def ensure_tablename(class_: type, table_name: str = None):
    """Guarantee that ``class_`` has a ``__tablename__`` attribute.

    A non-empty ``table_name`` is always written onto the class. Otherwise an
    existing ``__tablename__`` (defined on the class or inherited) is left
    alone. Only when none is found is one derived from the class name with
    ``camel_case_to_underscore``.
    """
    if table_name:
        class_.__tablename__ = table_name
    elif not hasattr(class_, '__tablename__'):
        class_.__tablename__ = camel_case_to_underscore(class_.__name__)


# Maps each mapped Table object to the class mapped onto it. Filled in by the
# mapping decorators in this module and read by ``ascend`` to resolve the class
# behind a foreign key's referred table.
table_to_mapper: dict = {}


class ProxyStep(NamedTuple):
    """One hop in the path recorded by ``EntityProxy`` / ``AttributeProxy``."""

    # Relationship key for steps built by ``descend``; foreign-key column name
    # for steps built by ``ascend``.
    name: str
    # NOTE: despite the field name, both ``descend`` and ``ascend`` in this
    # module store a remote-column-name -> local-column-name mapping here.
    local_column_to_remote: dict[str, str]
    # Optional alias. It is set by ``EntityProxy`` labelling or passed in by ``ascend``.
    # Annotated Optional explicitly: implicit Optional (``str = None``) is
    # rejected by PEP 484-conforming type checkers.
    alias: Optional[str] = None
    # True for steps built by ``ascend`` (following a FK to the referenced table).
    reverse: bool = False


@dataclasses.dataclass
class AttributeProxy:
    """A mapped attribute paired with the ``ProxyStep`` path that leads to it.

    ``EntityProxy.__getattr__`` creates these for ``InstrumentedAttribute``s.
    """

    name: str  # attribute name looked up on the entity
    attribute: InstrumentedAttribute  # the underlying mapped attribute
    path: list[ProxyStep] = dataclasses.field(default_factory=list)  # steps taken to reach the entity
    alias: Optional[str] = None  # optional label, set through ``label()``

    def label(self, alias: Optional[str]):
        """Return a new proxy identical to this one except for ``alias``.

        The path is copied with a slice, so the original and the copy never
        share the same path object.
        """
        copied_path = self.path[:]
        return AttributeProxy(
            self.name,
            self.attribute,
            copied_path,
            alias,
        )


class EntityProxy:
    """Stand-in for a mapped entity reached by following a path of ``ProxyStep``s.

    Attribute access is forwarded to the wrapped entity. Mapped attributes
    (``InstrumentedAttribute``) come back wrapped in an ``AttributeProxy`` that
    carries a copy of the path travelled so far. Any other attribute raises
    ``AttributeError``.
    """

    def __init__(self, entity, path: list[ProxyStep], alias: Optional[str] = None):
        """Create a proxy for ``entity`` reached via ``path``.

        entity: the mapped class this proxy stands for.
        path: steps leading to ``entity``. It is copied, so the caller's list
            is never mutated.
        alias: if truthy, stored on the last step of the path (the path must
            then be non-empty).
        """
        # Copy first: the alias handling below rewrites the last element, and
        # doing that in place would leak the alias into a list the caller owns.
        self._path = list(path)
        self._entity = entity
        if alias:
            self._path[-1] = self._path[-1]._replace(alias=alias)

    def get_entity(self):
        """Return the wrapped mapped class."""
        return self._entity

    def get_path(self) -> list[ProxyStep]:
        """Return the steps leading to the wrapped entity."""
        return self._path

    def get_alias(self) -> Optional[str]:
        """Return the alias stored on the last path step, or ``None``."""
        # The alias lives on the last ProxyStep (see ``__init__``). Reading it
        # from there keeps it correct after ``prepend_steps``, which rebuilds
        # the proxy without passing an alias.
        return self._path[-1].alias if self._path else None

    def prepend_steps(self, steps) -> 'EntityProxy':
        """Return a new proxy whose path is ``steps`` followed by this proxy's path."""
        return EntityProxy(self._entity, list(steps) + self._path)

    def __getattr__(self, item: str):
        # Only reached when normal attribute lookup fails. ``_entity`` is read
        # straight from the instance dict: on an instance created without
        # ``__init__`` (e.g. by ``copy.copy`` / unpickling), ``self._entity``
        # would re-enter ``__getattr__`` and recurse until RecursionError.
        try:
            entity = self.__dict__['_entity']
        except KeyError:
            raise AttributeError(item) from None
        result = getattr(entity, item)
        if isinstance(result, InstrumentedAttribute):
            return AttributeProxy(item, result, list(self._path))
        # Non-mapped attributes are deliberately not proxied. Normal lookup
        # already failed, so this raises AttributeError naming EntityProxy.
        return object.__getattribute__(self, item)

    def __str__(self):
        cls = type(self)
        return f'{cls.__name__}(entity={self._entity}, path={self._path})'

    def label(self, alias: Optional[str]):
        """Return a copy of this proxy with ``alias`` stored on its last step (if truthy)."""
        # __init__ copies the path, so the new proxy never shares our list.
        return EntityProxy(entity=self._entity, path=self._path, alias=alias)


def descend(attribute: InstrumentedAttribute | AttributeProxy, alias: Optional[str] = None) -> EntityProxy:
    """Step from a relationship attribute to the related mapped class.

    For an ``AttributeProxy``, descend from the wrapped attribute and prepend
    the proxy's own path. For a relationship attribute, build an
    ``EntityProxy`` for the related class with a single ``ProxyStep``. That
    step is named after the relationship key and maps remote column names to
    local column names, taken from ``local_remote_pairs``. ``alias`` (if
    truthy) is stored on the last step.

    Raises ``KeyError`` when the attribute's property is not a relationship.
    """
    if isinstance(attribute, AttributeProxy):
        inner_proxy = descend(attribute.attribute, alias=alias)
        return inner_proxy.prepend_steps(attribute.path)

    relationship = attribute.property
    if not isinstance(relationship, sqlalchemy.orm.relationships.RelationshipProperty):
        raise KeyError(f"Not found a way to descend from {attribute}") from None

    target_class = relationship.entity.entity
    column_map = {remote.name: local.name for local, remote in relationship.local_remote_pairs}
    first_step = ProxyStep(attribute.key, column_map)
    return EntityProxy(target_class, [first_step]).label(alias=alias)


def ascend(attribute: InstrumentedAttribute | AttributeProxy, alias: Optional[str] = None) -> EntityProxy:
    """Follow a foreign-key column to the mapped class it references.

    attribute: a mapped column attribute whose single column has exactly one
        foreign key, or an ``AttributeProxy`` wrapping one (its path is then
        prepended to the result).
    alias: optional alias stored on the produced step. It defaults to
        ``None``, matching ``descend``.

    Returns an ``EntityProxy`` for the referenced class. Its last step has
    ``reverse=True`` and maps the referred column name to the local column
    name.

    Raises ``KeyError`` when:

    - the attribute is not a column property;
    - the property does not have exactly one column;
    - the column does not have exactly one foreign key;
    - the referenced table is not registered in ``table_to_mapper``.
    """
    if isinstance(attribute, AttributeProxy):
        entity_proxy = ascend(attribute.attribute, alias=alias)
        return entity_proxy.prepend_steps(attribute.path)

    property_ = attribute.property
    if not isinstance(property_, sqlalchemy.orm.properties.ColumnProperty):
        raise KeyError(f"Not found a way to ascend from {attribute}. ascend works on columns with FK")

    # Explicit raises rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let a FK-less column reach ``first(...)`` and
    # fail with an unrelated error.
    if len(property_.columns) != 1:
        raise KeyError(
            f'failed to ascend from {attribute=} with {property_=} having {len(property_.columns)=}, expected 1'
        )
    column: NamedColumn = property_.columns[0]
    if len(column.foreign_keys) != 1:
        raise KeyError(f"failed to ascend from {column=} with {len(column.foreign_keys)=}, expected 1")

    foreign_key: ForeignKey = first(column.foreign_keys)
    local_column_name = foreign_key.parent.name
    referred_column_name = foreign_key.column.name
    steps = [ProxyStep(
        local_column_to_remote={referred_column_name: local_column_name},
        name=local_column_name, alias=alias, reverse=True
    )]

    referred_table = foreign_key.constraint.referred_table
    try:
        referred_class = table_to_mapper[referred_table]
    except KeyError:
        # Only tables mapped through this module's decorators are known here.
        raise KeyError(
            f"failed to ascend from {attribute}: referred table {referred_table} is not registered in table_to_mapper"
        ) from None
    return EntityProxy(referred_class, steps, alias)


# Registry the mapping decorators fall back to when none is passed explicitly.
default_registry = sqlalchemy_mapper_registry


def sqlalchemy_dataclass2(registry_or_class: sqlalchemy.orm.registry = default_registry, table_name: str = None):
    """Map a class as a SQLAlchemy dataclass via ``registry.mapped_as_dataclass``.

    Usable bare (``@sqlalchemy_dataclass2``) or called with a registry and/or a
    ``table_name``. If ``get_entity_by_name`` already knows an entity under the
    class's ``__tablename__``, that entity is returned unchanged. Otherwise the
    class is mapped and recorded via ``add_to_entity_name_to_class`` and in
    ``table_to_mapper``.
    """
    if isinstance(registry_or_class, sqlalchemy.orm.registry):
        target_registry, bare_class = registry_or_class, None
    elif isinstance(registry_or_class, type):
        target_registry, bare_class = default_registry, registry_or_class
    else:
        target_registry, bare_class = default_registry, None

    def decorate(cls: type):
        ensure_tablename(cls, table_name)
        # noinspection PyUnresolvedReferences
        already_mapped = get_entity_by_name(cls.__tablename__, None)
        if already_mapped is not None:
            return already_mapped

        # noinspection PyTypeChecker
        mapped_entity = target_registry.mapped_as_dataclass(cls)
        add_to_entity_name_to_class(cls)

        mapper = mapped_entity.__mapper__
        assert len(mapper.tables) == 1
        only_table = first(mapper.tables)
        table_to_mapper[only_table] = mapper.entity
        return mapped_entity

    return decorate(bare_class) if bare_class else decorate


def sqlalchemy_dataclass(registry_or_class: sqlalchemy.orm.registry = default_registry, table_name: str = None):
    """Turn a class into a stdlib dataclass and map it through ``registry.mapped``.

    Usable bare (``@sqlalchemy_dataclass``) or called with a registry and/or a
    ``table_name``. The class gets ``__sa_dataclass_metadata_key__ = "sa"``,
    which SQLAlchemy uses as the key for column/relationship definitions in
    dataclass field metadata. The resulting entity is recorded via
    ``add_to_entity_name_to_class`` and in ``table_to_mapper``.
    """
    if isinstance(registry_or_class, sqlalchemy.orm.registry):
        target_registry, bare_class = registry_or_class, None
    elif isinstance(registry_or_class, type):
        target_registry, bare_class = default_registry, registry_or_class
    else:
        target_registry, bare_class = default_registry, None

    def decorate(cls: type):
        cls.__sa_dataclass_metadata_key__ = "sa"
        ensure_tablename(cls, table_name)

        # noinspection PyTypeChecker
        as_dataclass = dataclasses.dataclass(cls)
        add_to_entity_name_to_class(as_dataclass)
        mapped_entity = target_registry.mapped(as_dataclass)

        mapper = mapped_entity.__mapper__
        assert len(mapper.tables) == 1
        only_table = first(mapper.tables)
        table_to_mapper[only_table] = mapper.entity
        return mapped_entity

    return decorate(bare_class) if bare_class else decorate


# def sql_relation_field(remote_class, use_list: bool = True, secondary=None, order_by=None):
#     # default_factory = list if use_list else dict
#     relation = sqlalchemy.orm.relationship(remote_class, uselist=use_list, secondary=secondary, order_by=order_by)
#     # relation.descend = lambda: print("HI")
#     metadata: dict[str, sqlalchemy.orm.RelationshipProperty] = {"sa": relation}
#     # kwargs = {"metadata": metadata, "default_factory": default_factory}
#     kwargs = {"metadata": metadata}
#     return dataclasses.field(**kwargs, default_factory=INoValue)

#
# def sql_field(type_: Union[Type[TypeEngine], TypeEngine], foreign_key: ForeignKey = None,
#               *, primary_key: bool = False, nullable: bool = True, unique: bool = False, index: bool = False,
#               default: Any = None, server_default: Any = None, default_factory: Callable = None) -> dataclasses.field:
#     # if primary_key:
#     #     nullable = unique = True
#     column_args = [type_, foreign_key]
#     column_kwargs = {"primary_key": primary_key} if primary_key else {"nullable": nullable, "unique": unique}
#     # if default is not None:
#     #     column_kwargs['default'] = NoValue
#     # column_kwargs['default'] = NoValue
#     if index:
#         column_kwargs['index'] = index
#     if server_default is not None:
#         column_kwargs['server_default'] = server_default
#     metadata = {"sa": Column(*column_args, **column_kwargs)}
#
#     kwargs = {"metadata": metadata}
#     if default_factory:
#         kwargs['default_factory'] = default_factory
#     # if nullable or default:
#     kwargs['default'] = NoValue
#     # kwargs['default'] = None if default is None else default
#     # if primary_key:
#     #     kwargs['init'] = False
#     return dataclasses.field(**kwargs)
