import logging
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union

from sqlalchemy import inspect, Column, ForeignKey
from sqlalchemy.orm.attributes import InstrumentedAttribute


logger = logging.getLogger(__name__)


@dataclass
class RelationDescription:
    """Describe the table that must be read to load one relationship.

    As populated by ``analyze_entity``: for a relationship without a secondary
    table this describes the target table directly; for a relationship through
    a secondary table it describes the secondary table, and ``remote_relation``
    describes the final target table.
    """

    # Name of the table described here.
    table_name: str
    # Keys of every column of ``table_name``.
    keys_to_select: List[str]
    # One column from the relationship's ``remote_side`` (a schema Column,
    # not an InstrumentedAttribute).
    distant_column: Column
    # Target table: key of its (last) primary-key column.
    # Secondary table: key of the secondary column paired with the owning
    # entity's table. None when no such column was found.
    distant_primary_key: Optional[str]
    # Only set for secondary-table relationships: description of the target table.
    remote_relation: Optional['RelationDescription'] = None


class EntityDescription:
    """Summary of a mapped entity's attributes and relationships."""

    def __init__(self,
                 primary_key: Tuple[Column, ...],
                 column_name_to_column: Dict[str, InstrumentedAttribute],
                 one_to_any_relations: Dict[str, RelationDescription],
                 many_to_any_relations: Dict[str, RelationDescription]
                 ):
        """Store the pieces of an entity description.

        Args:
            primary_key: primary-key columns of the mapped table (e.g.
                ``mapper.primary_key``); must contain at least one column.
            column_name_to_column: every mapped attribute key -> class
                attribute, relationship attributes included.
            one_to_any_relations: relationships without a secondary table,
                keyed by relationship name.
            many_to_any_relations: relationships through a secondary table,
                keyed by relationship name.

        Raises:
            IndexError: if ``primary_key`` is empty.
        """
        self.primary_key = primary_key
        # Convenience handle on the first primary-key column; for a composite
        # key the remaining columns are only available via ``primary_key``.
        self.p_key = primary_key[0]
        self.one_to_any_columns = one_to_any_relations
        self.many_to_any_columns = many_to_any_relations
        # One lookup over both relationship kinds; on a name clash the
        # one-to-any entry wins because it is applied last.
        self.relations = many_to_any_relations.copy()
        self.relations.update(one_to_any_relations)
        self.column_name_to_column = column_name_to_column

    def simple_column_names(self) -> set:
        """Return the attribute keys that are not relationship keys."""
        return set(self.column_name_to_column.keys()) - set(self.relations.keys())


# Module-level cache mapping a mapped entity to the EntityDescription that
# analyze_entity built for it.
_entity_to_description: Dict[Any, 'EntityDescription'] = dict()


def remove_entity_description(entity) -> bool:
    """Evict *entity* from the description cache.

    Returns True if a non-None description was cached for *entity* (and has
    now been removed), False otherwise.
    """
    evicted = _entity_to_description.pop(entity, None)
    return evicted is not None


def _table_keys_and_primary_key(table) -> Tuple[List[str], Optional[str]]:
    """Return the keys of all columns of *table* and the key of its last
    primary-key column (None if it has no primary-key column)."""
    keys = [column.key for column in table.columns]
    primary_keys = [column.key for column in table.columns if column.primary_key]
    return keys, (primary_keys[-1] if primary_keys else None)


def analyze_entity(entity, use_cache: bool = True) -> EntityDescription:
    """Return an EntityDescription for the mapped class *entity*.

    Results are cached per entity in ``_entity_to_description``.

    Args:
        entity: a mapped class accepted by ``sqlalchemy.inspect``.
        use_cache: if True (default) and *entity* was analyzed before, return
            the cached description; if False, always rebuild it and replace
            the cached entry.

    Returns:
        The (possibly freshly built) EntityDescription for *entity*.
    """
    if use_cache and entity in _entity_to_description:
        return _entity_to_description[entity]

    mapper = inspect(entity)
    # Declarative classes expose __tablename__; other mappings fall back to
    # the name of the table the mapper is bound to.
    entity_table_name = getattr(entity, '__tablename__', None) or mapper.local_table.name

    # Every mapped attribute (plain columns AND relationships), by key.
    all_columns = {attr.key: getattr(entity, attr.key) for attr in mapper.attrs}

    one_to_any_relations = {}
    many_to_any_relations = {}
    for relation in mapper.relationships:
        # Read one remote-side column WITHOUT removing it: remote_side is part
        # of the relationship's configuration, so popping from it corrupts the
        # mapper and makes a repeated analysis fail on an emptied set.
        foreign_key = next(iter(relation.remote_side))

        target_table = relation.target
        keys_to_select, distant_primary_key = _table_keys_and_primary_key(target_table)

        if relation.secondary is None:
            one_to_any_relations[relation.key] = RelationDescription(
                target_table.name, keys_to_select, foreign_key, distant_primary_key)
            continue

        # Relationship through a secondary table: describe the secondary table
        # and nest the target-table description inside it.
        remote_table = RelationDescription(
            target_table.name, keys_to_select, foreign_key, distant_primary_key)
        secondary_keys = []
        # Falls back to the target's primary key if no pair matches this entity.
        secondary_primary_key = distant_primary_key
        # NOTE(review): for a self-referential many-to-many both local columns
        # belong to this entity's table, so the last matching pair wins here —
        # confirm that is the intended column.
        for local_column, remote_column in relation.local_remote_pairs:
            if local_column.table.name == entity_table_name:
                secondary_primary_key = remote_column.key
            secondary_keys.append(remote_column.key)
        many_to_any_relations[relation.key] = RelationDescription(
            relation.secondary.name, secondary_keys, foreign_key,
            secondary_primary_key, remote_table)

    description = EntityDescription(
        mapper.primary_key, all_columns, one_to_any_relations, many_to_any_relations)
    _entity_to_description[entity] = description
    return description
