#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
from typing import Sequence, Iterable, TypeVar

import sqlalchemy.orm
import sqlalchemy.sql.elements
from sqlalchemy import Integer

from ..entity import Entity


def is_list_relation(relation: sqlalchemy.orm.RelationshipProperty) -> bool:
    """
    Tell whether ``relation`` is a collection-valued relationship.

    This simply reports the relationship's ``uselist`` flag.
    """
    uses_collection = relation.uselist
    return uses_collection


def get_related_entity_type(relation: sqlalchemy.orm.RelationshipProperty) -> type[Entity]:
    """
    Return the mapped class on the target (far) side of ``relation``.

    ``relation.argument`` holds whatever was passed to ``relationship()``.
    That may be the class itself, but it may also be a string class name, a
    lambda/callable or a Mapper, which SQLAlchemy resolves only during mapper
    configuration. ``relation.mapper`` is the resolved target Mapper, so its
    ``class_`` is always the actual mapped class.
    """
    return relation.mapper.class_


def get_relation_owner_entity_type(relation: sqlalchemy.orm.RelationshipProperty) -> type[Entity]:
    """
    Return the mapped class that owns ``relation``.

    ``relation.parent`` is the Mapper the relationship property belongs to.
    Its ``entity`` is the mapped class.
    """
    owner_mapper = relation.parent
    return owner_mapper.entity


def get_remote_column(relation: sqlalchemy.orm.RelationshipProperty) -> sqlalchemy.Column:
    """
    Return the right-hand column of ``relation``'s class-level join expression.

    Only relationships whose class attribute expression is a single
    ``BinaryExpression`` (e.g. ``a.id == b.a_id``) are supported. Any other
    expression type (such as a multi-column ``AND`` join) raises ``TypeError``.

    NOTE(review): the right operand is assumed to be the remote column. That
    depends on operand order in the join condition and may not hold for every
    relationship direction (e.g. many-to-one). ``relation.local_remote_pairs``
    may be a more reliable source; confirm against callers before changing.
    """
    join_clause = relation.class_attribute.expression
    if isinstance(join_clause, sqlalchemy.sql.elements.BinaryExpression):
        return join_clause.right
    raise TypeError(f"failed to process relation: {relation}")


def is_auto_incremented(column: sqlalchemy.Column) -> bool:
    """
    Approximate whether ``column`` is an auto-incremented primary key.

    Only the basic rule is implemented. The column must be part of the
    primary key, have an ``Integer`` (or subclass) type and no foreign keys,
    and its ``autoincrement`` setting must be truthy (``True``, ``"auto"``,
    ``"ignore_fk"``). Some logic is ignored and may be added later, e.g. the
    composite-primary-key rules and ``autoincrement="ignore_fk"`` / ``True``
    on columns that do have foreign keys. See
    https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column.params.autoincrement

    Always returns a real ``bool``: ``False`` for every non-qualifying column.
    """
    if not column.primary_key or column.foreign_keys:
        return False
    if not isinstance(column.type, Integer):
        return False
    # ``autoincrement`` may be a bool or a string such as "auto"; any truthy
    # value is treated as "auto-incremented" here.
    return bool(column.autoincrement)


def has_server_default(column: sqlalchemy.Column) -> bool:
    """
    Report whether the database side provides a value for ``column``.

    This is True when the column declares a ``server_default`` or when
    ``is_auto_incremented`` considers it auto-incremented.
    """
    if column.server_default is not None:
        return True
    # bool() guarantees a real bool even if the auto-increment check yields None.
    return bool(is_auto_incremented(column))


def has_client_default(column: sqlalchemy.Column) -> bool:
    """Report whether ``column`` declares a client-side ``default``."""
    client_default = column.default
    return client_default is not None


def has_default(column: sqlalchemy.Column) -> bool:
    """
    Report whether ``column`` has any default, client-side or server-side.

    The result is coerced to ``bool`` so a ``None`` coming from a helper never
    leaks to callers.
    """
    return bool(has_client_default(column) or has_server_default(column))


def get_column_pairs_from_relation(
        relation: sqlalchemy.orm.RelationshipProperty,
) -> Sequence[tuple[sqlalchemy.ColumnElement, sqlalchemy.ColumnElement]]:
    """
    Return the column pairs SQLAlchemy synchronizes for ``relation``.

    This exposes ``relation.synchronize_pairs`` as-is. Presumably each pair
    is (referenced column, foreign-key column); confirm before relying on
    the order.
    """
    synced_pairs = relation.synchronize_pairs
    return synced_pairs


T = TypeVar("T")


def split_by_type(iterable: Iterable[T]) -> dict[type[T], list[T]]:
    """
    Group the items of ``iterable`` by their exact runtime type.

    Subclass instances go to their own group; no ``isinstance`` merging is
    done. Both the keys and each group's items keep first-seen order.
    """
    groups: dict[type[T], list[T]] = {}
    for item in iterable:
        item_type = type(item)
        if item_type not in groups:
            groups[item_type] = []
        groups[item_type].append(item)
    return groups
