#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import dataclasses
import functools
from typing import Callable, Optional

import sqlalchemy
import sqlalchemy.orm
from sqlalchemy import ForeignKey, Column
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy_tools.entity_helpers.sqlalchemy_base import sqlalchemy_mapper_registry
from sqlalchemy_tools.utils.first import first

from .defs import SA_DATACLASS_METADATA_KEY
from .ensure_tablename import ensure_tablename
from .move import relation_to_descend
from .proxy import EntityProxy, ProxyStep
from ..entity import Entity

# Registry used by ``sqlalchemy_dataclass_entity`` when no explicit registry is given.
default_registry = sqlalchemy_mapper_registry


@functools.cache
def sqlalchemy_dataclass_entity(
    registry_or_class: sqlalchemy.orm.registry = default_registry,
    table_name: Optional[str] = None,
):
    """Turn an ``Entity`` subclass into a SQLAlchemy-mapped dataclass.

    Usable both bare and with arguments::

        @sqlalchemy_dataclass_entity
        class A(Entity): ...

        @sqlalchemy_dataclass_entity(my_registry, table_name="a")
        class B(Entity): ...

    Besides mapping, every public mapped attribute is wired for navigation:
    relationships get a "descend" callable registered in ``relation_to_descend``,
    and columns carrying a foreign key get an ``ascend`` callable attached.

    The call is memoized with ``functools.cache``. Repeated calls with the same
    arguments return the cached result instead of mapping again. For the bare
    form, that result is the already-mapped class.

    :param registry_or_class: the ``sqlalchemy.orm.registry`` to map into, or the
        class itself when used as a bare decorator. ``None`` selects
        ``default_registry``.
    :param table_name: table name handed to ``ensure_tablename``.
    :returns: the mapped class (bare form) or a decorator producing it.
    :raises TypeError: if ``registry_or_class`` is not a registry, a class or
        ``None``, or if a mapped attribute has an unsupported property type.
    """
    registry = default_registry
    class_ = None
    if isinstance(registry_or_class, sqlalchemy.orm.registry):
        registry = registry_or_class
    elif isinstance(registry_or_class, type):
        class_ = registry_or_class
    elif registry_or_class is not None:
        # e.g. ``@sqlalchemy_dataclass_entity("users")`` used to be accepted
        # silently, dropping the intended table name; fail loudly instead.
        raise TypeError(
            "registry_or_class must be a sqlalchemy.orm.registry, a class or None, "
            f"got {type(registry_or_class).__name__}"
        )

    def inner(entity_type: type[Entity]) -> type[Entity]:
        entity_type.__sa_dataclass_metadata_key__ = SA_DATACLASS_METADATA_KEY
        ensure_tablename(entity_type, table_name)

        mapped_entity = registry.mapped(dataclasses.dataclass(entity_type))
        mapper = mapped_entity.__mapper__
        assert len(mapper.tables) == 1
        table = first(mapper.tables)
        # Reverse index consulted by ``make_ascend`` to find the class behind a
        # foreign key's referred table.
        _table_to_mapper[table] = mapper.entity

        for key in dir(mapped_entity):
            if key.startswith("_"):
                continue
            value = getattr(mapped_entity, key)
            if isinstance(value, InstrumentedAttribute):
                _register_navigation(key, value.property)

        return mapped_entity

    if class_ is not None:
        return inner(class_)
    return inner


def _register_navigation(key: str, value_property) -> None:
    """Dispatch navigation wiring for mapped attribute ``key`` by property type."""
    if isinstance(value_property, sqlalchemy.orm.relationships.Relationship):
        _register_relationship(key, value_property)
    elif isinstance(value_property, sqlalchemy.orm.properties.ColumnProperty):
        _register_column(value_property)
    else:
        raise TypeError(f"unexpected type: {type(value_property)}")


def _register_relationship(key: str, relationship) -> None:
    """Register a descend callable for relationship attribute ``key``."""
    remote_class = relationship.entity.entity
    # Column of the related table -> the local column it is joined to.
    remote_column_to_local = {
        remote_column.name: local_column.name
        for local_column, remote_column in relationship.local_remote_pairs
    }
    descension_proxy = EntityProxy(remote_class, [ProxyStep(key, remote_column_to_local)])
    relation_to_descend[relationship] = make_descend(descension_proxy)


def _register_column(column_property) -> None:
    """Attach an ``ascend`` callable to a single column that has a foreign key."""
    assert len(column_property.columns) == 1
    column: Column = column_property.columns[0]
    assert len(column.foreign_keys) <= 1
    if column.foreign_keys:
        foreign_key: ForeignKey = first(column.foreign_keys)
        column.ascend = make_ascend(foreign_key)


# Reverse index from a mapped table to the class mapped onto it. Despite the
# name, the values are ``mapper.entity`` classes, not mapper objects. It is
# filled by ``sqlalchemy_dataclass_entity`` when a class is mapped, and read by
# ``make_ascend`` to resolve the class behind a foreign key's referred table.
_table_to_mapper: dict[sqlalchemy.Table, type] = {}


def make_descend(proxy: EntityProxy) -> Callable[[Optional[str]], EntityProxy]:
    """Build the descend callable registered for a relationship.

    :param proxy: proxy describing the step from the owning class into the
        relationship's target class.
    :returns: ``descend(alias=None)``, which returns ``proxy.label(alias=alias)``.
    """
    def inner(alias: Optional[str] = None) -> EntityProxy:
        return proxy.label(alias=alias)
    return inner


def make_ascend(foreign_key: ForeignKey) -> Callable[[Optional[str]], EntityProxy]:
    """Build the ``ascend`` callable for a column that carries ``foreign_key``.

    The returned callable builds an ``EntityProxy`` over the class mapped to the
    foreign key's referred table. The path is a single reversed ``ProxyStep``
    named after the local FK column.

    All lookups happen when the callable is invoked, not here. ``_table_to_mapper``
    is populated as each class is decorated, so the referred class may be mapped
    after the class that owns this column.

    :param foreign_key: the single foreign key of the local column.
    :returns: ``ascend(alias=None)``. ``alias`` is passed to both the step and
        the proxy.
    :raises KeyError: (from the returned callable) if the referred table was not
        mapped through ``sqlalchemy_dataclass_entity``.
    """
    def inner(alias: Optional[str] = None) -> EntityProxy:
        local_column_name = foreign_key.parent.name
        referred_column_name = foreign_key.column.name
        steps = [ProxyStep(
            local_column_to_remote={referred_column_name: local_column_name},
            name=local_column_name, alias=alias, reverse=True,
        )]

        referred_table = foreign_key.constraint.referred_table
        try:
            referred_class = _table_to_mapper[referred_table]
        except KeyError:
            raise KeyError(
                f"table {referred_table.name!r} referred to by foreign key column "
                f"{local_column_name!r} is not mapped via sqlalchemy_dataclass_entity"
            ) from None
        return EntityProxy(referred_class, steps, alias)
    return inner
