r"""

   | User | SpoAdmin | RegAdmin |
----------------------------------------------------------
   |  X   |    X     |    X     | create_contract_proposal - entity + create votes
   |  X   |    X     |    X     | update_contract_proposal - entity + chat??
   |  X   |    X     |    X     | delete_contract_proposal - entity
   |  X   |    X     |    X     | count_contract_proposal - entity
   |  X   |    X     |    X     | list_contract_proposal - entity
   |  X   |    X     |    X     | get_contract_proposal - entity

   |  X   |          |          | begin
   |      |    X     |          | rollback
   |      |    X     |          | invite_regional_admin
   |      |          |    X     | approve
   |      |    X     |          | commit


DB
          /--> vote
process--|
          \--> contract_proposal (entity)


           /--> vote
entity->|
           \--> contract_upsert (entity)

"""
import contextlib
import copy
import dataclasses
import itertools
from collections import defaultdict
from typing import Any, Optional, Union, Iterable, Type, TypeVar

import sqlalchemy
from async_tools import  AsyncInitable
from init_helpers.dict_to_dataclass import NoValue, dict_to_dataclass
from sqlalchemy import select, Table
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy_tools.database_connector.database_connector import DatabaseConnector
from sqlalchemy_tools.entity_helpers.sqlalchemy_base import Values
from dict_caster.extras import first, to_list

from .entity_view.viewable_entity import ViewableEntity, SelectedView, get_related_entity_type, get_remote_column, \
    get_column_pairs_from_relation, DeleteView, DeltaView, InsertView
from .entity_view.layers.select_atoms import ERoot
from .entity_view.layers.select_middle import e_root_to_sql
from extended_logger import get_logger
from .entity_field import NoDefault

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
from .utils import split_by_type

# Type variable bound to ViewableEntity, used by the select helpers to tie the
# queried root entity type to the entity type they return.
T = TypeVar('T', bound=ViewableEntity)
# Module-level logger named after this module.
logger = get_logger(__name__)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #


class NotFound(KeyError):
    """Raised when a lookup that must yield one entity yields none (e.g. ``select_single_entity``)."""


class MultipleFound(ValueError):
    """Raised when a lookup that must yield one entity yields more than one (e.g. ``select_single_entity``)."""


class EntityDatabase(AsyncInitable):
    """Insert, update, delete and select ``ViewableEntity`` objects through SQLAlchemy.

    Write operations run inside an ``AsyncSession`` supplied by the caller. The select helpers
    accept an optional session and otherwise obtain one from ``context.database_connector``.
    """

    @dataclasses.dataclass
    class Context:
        """Dependencies of ``EntityDatabase``."""
        # Source of sessions for the select helpers; the whole context is also handed to
        # ``ViewableEntity.resolve_insert_defaults``.
        database_connector: DatabaseConnector

    async def _on_start(self) -> None:
        # Nothing to initialise: sessions are obtained per call.
        pass

    def __init__(self, context: Context):
        self.context = context
        super().__init__()

    async def insert_plain_entity(self, entity: ViewableEntity, database_session: AsyncSession) -> int:
        """Insert a single entity without its relations; see ``insert_plain_entities``.

        :return: the number of inserted entities (1).
        """
        return await self.insert_plain_entities([entity], database_session)

    async def insert_plain_entities(self, plain_entities: list[ViewableEntity], database_session: AsyncSession) -> int:
        """Insert entities (plain columns only, no relations) and write generated primary keys back.

        Insert defaults are resolved on each entity first. Entities are then grouped by
        ``(table, sorted column names)`` so every group becomes one multi-row
        ``INSERT ... RETURNING <primary key columns>`` statement.

        :return: the number of entities passed in.
        """
        group_to_entities: dict[tuple, list[ViewableEntity]] = defaultdict(list)
        group_to_entity_dicts: dict[tuple, list[dict]] = defaultdict(list)
        for entity in plain_entities:
            await entity.resolve_insert_defaults(self.context)
            table = entity.get_table()
            entity_dict = entity.as_dict(plain=True)
            group_key = (table, tuple(sorted(entity_dict)))
            group_to_entity_dicts[group_key].append(entity_dict)
            group_to_entities[group_key].append(entity)

        for group_key, entity_dicts in group_to_entity_dicts.items():
            table, _column_names = group_key
            entities = group_to_entities[group_key]
            primary_keys = first(entities).get_primary_key_columns()
            inserted_primary_key_rows = await self._insert_plain_entities_with_same_keys(
                table, entity_dicts, primary_keys, database_session,
            )
            # Returned rows are matched to entities positionally, which relies on RETURNING
            # producing rows in the same order as the VALUES list.
            for entity, row in zip(entities, inserted_primary_key_rows):
                for primary_key_column, value in zip(primary_keys, row):
                    setattr(entity, primary_key_column.name, value)

        return len(plain_entities)

    @classmethod
    async def _insert_plain_entities_with_same_keys(cls, table: Table, entity_dicts: list[dict], returning: list,
                                                    database_session: AsyncSession) -> list[tuple]:
        """Run one multi-row INSERT into ``table`` and return the ``returning`` columns of each row."""
        statement = sqlalchemy.insert(table).values(entity_dicts).returning(*returning)
        answer = await database_session.execute(statement)
        return [tuple(row) for row in answer]

    @classmethod
    async def plain_update(cls,
                           entity: Type[ViewableEntity],
                           column_name_to_value_list: list[dict[str, Any]],
                           db_session: AsyncSession
                           ) -> int:
        """Bulk-update plain column values of ``entity`` rows matched by primary key.

        Keys that are not plain columns of ``entity`` are ignored. Rows are grouped by their set of
        column names. Each group is turned into a ``Values`` construct aliased ``states``, and one
        UPDATE statement is issued that matches every primary key column against it. Groups that
        contain only primary key columns are skipped.

        :raises ValueError: if a row does not contain every primary key column.
        :return: the sum of ``rowcount`` over the executed statements.
        """
        plain_column_names = entity.get_column_names()
        names_to_value_rows: dict[tuple, list[tuple]] = defaultdict(list)
        for column_name_to_value in column_name_to_value_list:
            known = {name: value for name, value in column_name_to_value.items() if name in plain_column_names}
            names = tuple(sorted(known))
            names_to_value_rows[names].append(tuple(known[name] for name in names))

        primary_key_columns = entity.get_primary_key_columns()
        primary_key_names = {column.name for column in primary_key_columns}
        entity_name = getattr(entity, '__name__', type(entity).__name__)
        updated_rows_amount = 0
        for column_names, value_rows in names_to_value_rows.items():
            present_names = set(column_names)
            # Every row must carry the complete primary key, otherwise it cannot be matched;
            # a subset test is required (a partially present key is also an error).
            if not primary_key_names <= present_names:
                raise ValueError(f'got {entity_name} to update '
                                 f'without necessary column_names: {primary_key_names}, '
                                 f'got only {present_names}')
            names_to_update = present_names - primary_key_names
            if not names_to_update:
                continue

            columns = [getattr(entity, name) for name in column_names]
            states = Values(value_rows, columns).alias('states')
            # The Values alias exposes its columns positionally as column1, column2, ...
            name_to_state_column = {
                name: getattr(states.c, f'column{index}') for index, name in enumerate(column_names, start=1)
            }

            stmt = entity.get_table().update().values(
                **{name: name_to_state_column[name] for name in names_to_update}
            )
            for column in primary_key_columns:
                stmt = stmt.where(column == name_to_state_column[column.name])

            answer = await db_session.execute(stmt)
            updated_rows_amount += answer.rowcount

        return updated_rows_amount

    async def insert_cascade(self, cascade_entities: list[ViewableEntity], database_session: AsyncSession) -> None:
        """Insert entities and, recursively, the related entities attached to their relations.

        Parents are inserted first, so their generated keys can be copied onto the related
        entities' columns (``patch_children``) before those are inserted. Relations left at
        ``NoValue``/``NoDefault`` are skipped.
        """
        for entities in split_by_type(cascade_entities).values():
            await self.insert_plain_entities(entities, database_session)

            for relation_name, relation in first(entities).get_key_to_relation().items():
                relation_column_pairs = get_column_pairs_from_relation(relation)
                related_entities: list[ViewableEntity] = []
                for entity in entities:
                    relation_value = getattr(entity, relation_name)
                    if relation_value == NoValue or relation_value == NoDefault:
                        continue
                    children = to_list(relation_value)
                    self.patch_children(entity, children, relation_column_pairs)
                    related_entities.extend(children)
                # One recursive insert per relation rather than one per parent entity.
                if related_entities:
                    await self.insert_cascade(related_entities, database_session)

    @staticmethod
    def patch_children(
            parent: ViewableEntity, children: list[ViewableEntity],
            relation_column_pairs: list[tuple[sqlalchemy.Column, sqlalchemy.Column]]
    ) -> None:
        """Copy each local-column value of ``parent`` onto the paired remote column of every child."""
        for relation_local_column, relation_remote_column in relation_column_pairs:
            relation_remote_column_key = relation_remote_column.key
            relation_local_column_value = getattr(parent, relation_local_column.key)
            for child in children:
                setattr(child, relation_remote_column_key, relation_local_column_value)

    @staticmethod
    async def delete_by_delete_views(delete_views: list['ViewableEntity[DeleteView]'], db_session: AsyncSession) -> None:
        """Delete the rows identified by ``delete_views``, issuing one DELETE statement per view type."""
        views_by_type: dict[type, list] = defaultdict(list)
        for view in delete_views:
            views_by_type[type(view)].append(view)

        for view_type, views in views_by_type.items():
            primary_key_columns = view_type.get_primary_key_columns()
            primary_keys_values = [view.get_primary_key_values() for view in views]
            prepared_primary_keys = sqlalchemy.tuple_(*primary_key_columns)
            stmt = sqlalchemy.delete(view_type.get_entity_type()) \
                .where(prepared_primary_keys.in_(primary_keys_values))
            await db_session.execute(stmt)

    async def update_by_delta_views(self, delta_views: list['ViewableEntity[DeltaView]'], db_session: AsyncSession) -> None:
        """Apply delta views: update changed columns, then create/update/delete related children.

        ``delta_views`` are deep-copied first, so values patched onto children (and keys generated
        for inserted children) are written to the copies, not to the caller's objects. Column
        values that expose a ``new`` attribute are replaced by that attribute before updating.
        """
        deltas_by_type: dict[type, list] = defaultdict(list)
        for delta in copy.deepcopy(delta_views):
            deltas_by_type[type(delta)].append(delta)

        for view_type, deltas in deltas_by_type.items():
            entity_type = view_type.get_entity_type()
            await self.plain_update(
                entity_type,
                [
                    {key: value.new if hasattr(value, 'new') else value
                     for key, value in delta.as_dict(plain=True).items()}
                    for delta in deltas
                ],
                db_session,
            )

            entities_to_create: list[ViewableEntity[InsertView]] = []
            entities_to_update: list[ViewableEntity[DeltaView]] = []
            entities_to_delete: list[ViewableEntity[DeleteView]] = []
            for relation_name, relation in entity_type.get_key_to_relation().items():
                relation_column_pairs = get_column_pairs_from_relation(relation)
                for delta in deltas:
                    relation_value = getattr(delta, relation_name)
                    if relation_value in (NoValue, NoDefault):
                        continue
                    for children, bucket in (
                        (relation_value.create, entities_to_create),
                        (relation_value.update, entities_to_update),
                        (relation_value.delete, entities_to_delete),
                    ):
                        if children:
                            self.patch_children(parent=delta, children=children,
                                                relation_column_pairs=relation_column_pairs)
                            bucket.extend(children)

            await self.delete_by_delete_views(entities_to_delete, db_session)
            await self.update_by_delta_views(entities_to_update, db_session)
            await self.insert_cascade(entities_to_create, db_session)

    async def update(self, entity_type: Type[ViewableEntity], cascade_entities: list[ViewableEntity],
                     db_session: AsyncSession) -> tuple[int, int, int]:
        """Update the existing ``entity_type`` rows identified by ``cascade_entities``' primary keys.

        Entities whose key matches no existing row are ignored at this level (``insert_allowed``
        is False); relations are still synchronised recursively by ``update_cascade``.

        :return: ``(created, updated, deleted)`` row counts for the top level.
        """
        primary_key_values = [entity.get_primary_key_values() for entity in cascade_entities]
        filters = [sqlalchemy.tuple_(*entity_type.get_primary_key_columns()).in_(primary_key_values)]
        return await self.update_cascade(
            entity_type, cascade_entities, filters, db_session, insert_allowed=False
        )

    async def update_cascade(self, entity_type: Type[ViewableEntity], cascade_entities: list[ViewableEntity],
                             filters: list[BinaryExpression], db_session: AsyncSession,
                             insert_allowed: bool = True) -> tuple[int, int, int]:
        """Synchronise the ``entity_type`` rows selected by ``filters`` with ``cascade_entities``.

        * Existing rows whose primary key matches an entity are updated (``plain_update``).
        * Entities without a matching existing row are inserted when ``insert_allowed``.
        * Existing rows selected by ``filters`` but absent from ``cascade_entities`` are deleted.

        Relations supplied on the inserted/updated entities are then synchronised recursively.
        The recursion is scoped to children of the parents that actually supplied that relation;
        a relation left at ``NoValue``/``NoDefault``/``None`` leaves its existing rows untouched.

        :return: ``(created, updated, deleted)`` row counts for this level only.
        """
        answer = await db_session.execute(select(*entity_type.get_primary_key_columns()).where(*filters))
        existent_p_keys = [tuple(row) for row in answer.all()]
        p_keys_for_delete = existent_p_keys.copy()

        entities_to_create: list[ViewableEntity] = []
        entities_to_update: list[ViewableEntity] = []
        for entity in cascade_entities:
            p_key = entity.get_primary_key_values()
            if p_key and p_key in existent_p_keys:
                entities_to_update.append(entity)
                # The same key may appear twice in cascade_entities; only remove it once.
                if p_key in p_keys_for_delete:
                    p_keys_for_delete.remove(p_key)
            elif insert_allowed:
                entities_to_create.append(entity)

        deleted_rows_amount = await self.delete_by_primary_keys(entity_type, p_keys_for_delete, db_session)

        logger.debug(f"entities_to_update: {entities_to_update}")
        updated_rows_amount = await self.plain_update(
            entity_type, [obj.as_dict(plain=True) for obj in entities_to_update], db_session)
        logger.debug(f"entities_to_create: {entities_to_create}")
        created_rows_amount = await self.insert_plain_entities(entities_to_create, db_session)

        related_type_to_states: dict[Any, list[ViewableEntity]] = {}
        related_type_to_column: dict[Any, Any] = {}
        related_type_to_parent_keys: dict[Any, list[tuple]] = {}
        for entity in itertools.chain(entities_to_create, entities_to_update):
            object_primary_key_values = entity.get_primary_key_values()

            for relation_column_name, relation in entity.get_key_to_relation().items():
                relation_values = getattr(entity, relation_column_name, NoValue)
                # Relation not supplied for this parent: its existing children must not be touched.
                if relation_values == NoValue or relation_values == NoDefault or relation_values is None:
                    continue

                distant_entity = get_related_entity_type(relation)
                distant_column = get_remote_column(relation)
                if not relation_values:
                    relation_values = []
                elif not isinstance(relation_values, list):
                    relation_values = [relation_values]
                # TODO: composite foreign keys — only the first primary key value is propagated.
                for value in relation_values:
                    setattr(value, distant_column.key, object_primary_key_values[0])
                related_type_to_states.setdefault(distant_entity, []).extend(relation_values)
                related_type_to_column[distant_entity] = distant_column
                related_type_to_parent_keys.setdefault(distant_entity, []).append(
                    tuple(object_primary_key_values)
                )

        for related_type, states in related_type_to_states.items():
            # Only children of the parents that supplied this relation are candidates for deletion.
            related_filters = [
                sqlalchemy.tuple_(related_type_to_column[related_type]).in_(related_type_to_parent_keys[related_type])
            ]
            await self.update_cascade(related_type, states, related_filters, db_session)

        return created_rows_amount, updated_rows_amount, deleted_rows_amount

    @classmethod
    async def delete_by_primary_keys(cls, entity_type: Type[ViewableEntity],
                                     primary_keys_values: Iterable[Union[Any, Iterable[Any]]],
                                     database_session: AsyncSession) -> int:
        """Delete ``entity_type`` rows whose primary key tuple is in ``primary_keys_values``.

        Each item must be a sized sequence with one value per primary key column.

        :raises ValueError: if an item's length differs from the number of primary key columns.
        :return: the number of rows reported deleted by ``DELETE ... RETURNING``.
        """
        # Materialise first so generators are handled and emptiness can be tested reliably.
        primary_keys_values = list(primary_keys_values)
        if not primary_keys_values:
            return 0
        primary_key_columns = entity_type.get_primary_key_columns()
        prepared_primary_keys_values = []
        for values in primary_keys_values:
            if len(values) != len(primary_key_columns):
                raise ValueError(f"entity {entity_type} has {len(primary_key_columns)} primary keys: "
                                 f"{primary_key_columns}, but got {len(values)} values: {values}")
            prepared_primary_keys_values.append(tuple(values))

        prepared_primary_keys = sqlalchemy.tuple_(*primary_key_columns)
        stmt = sqlalchemy.delete(entity_type) \
            .where(prepared_primary_keys.in_(prepared_primary_keys_values)) \
            .returning(*primary_key_columns)
        answer = await database_session.execute(stmt)
        return len(answer.fetchall())

    async def select_raw(self, sql_code: str, params, database_session: Optional[AsyncSession] = None) -> Any:
        """Execute raw SQL with bound ``params`` and return the first column of the first row.

        Returns None when the query yields no row. Without ``database_session`` a session is
        obtained from the connector and entered as an async context manager for this query only.
        A caller-provided session is used as-is.
        """
        if database_session is None:
            database_session = self.context.database_connector.get_session()
            session_scope = database_session
        else:
            # No-op async context manager: the caller owns the session's lifetime.
            session_scope = contextlib.AsyncExitStack()
        async with session_scope:
            logger.debug("start select")
            cursor_result = await database_session.execute(sqlalchemy.text(sql_code), params)
            result = cursor_result.scalar()
            logger.debug("select result got")

        return result

    async def select_e_root(self, e_root: ERoot[T], database_session: Optional[AsyncSession] = None) -> list[dict]:
        """Render ``e_root`` to SQL and return the scalar result of ``select_raw``."""
        return await self.select_raw(*e_root_to_sql(e_root), database_session=database_session)

    async def select_count(self, e_root: ERoot[T], database_session: Optional[AsyncSession] = None) -> int:
        """Render ``e_root`` as a count query and return its scalar result."""
        return await self.select_raw(*e_root_to_sql(e_root, count=True), database_session=database_session)

    async def select_entity_list(self, e_root: ERoot[T], database_session: Optional[AsyncSession] = None) -> list[T]:
        """Select ``e_root`` and convert every returned mapping to ``e_root.entity[SelectedView]``."""
        answer = await self.select_e_root(e_root, database_session=database_session)
        result_type = e_root.entity[SelectedView]
        # select_raw returns None when there is no row/value; treat that as an empty result.
        return [dict_to_dataclass(line, result_type) for line in answer or ()]

    async def select_single_entity(self, e_root: ERoot[T], database_session: Optional[AsyncSession] = None) -> T:
        """Select ``e_root`` and convert the single returned mapping to ``e_root.entity[SelectedView]``.

        :raises NotFound: if the query returns nothing.
        :raises MultipleFound: if the query returns more than one item.
        """
        answer = await self.select_e_root(e_root, database_session=database_session)
        if not answer:
            raise NotFound()

        if len(answer) > 1:
            raise MultipleFound()

        result_type = e_root.entity[SelectedView]
        return dict_to_dataclass(answer[0], result_type)

