#  Copyright (C) 2022
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import dataclasses
from typing import Union, Type, Any, Callable, Awaitable

from init_helpers.dict_to_dataclass import NoValue
from sqlalchemy import ForeignKey, Column
from sqlalchemy.orm import relationship
from sqlalchemy.sql.type_api import TypeEngine


def sql_relation_field(remote_class, *, init: bool = True, use_list: bool = True,
                       order_by: Union[str, bool, list] = False, lazy: str = "select",
                       **relationship_kwargs) -> dataclasses.Field:
    """
    Build a dataclass field that carries an SQLAlchemy ``relationship`` in its ``"sa"`` metadata.

    :param remote_class: relationship target, passed as the first argument to ``relationship()``.
    :param init: whether the field is a parameter of the dataclass ``__init__``.
    :param use_list: forwarded to ``relationship()`` as ``uselist``;
        ``True`` for a collection, ``False`` for a single (scalar) reference.
    :param order_by: forwarded to ``relationship()``.
    :param lazy: loader strategy forwarded to ``relationship()``.
    :param relationship_kwargs: any further keyword arguments for ``relationship()``.
    :return: a ``dataclasses.Field``. Its default is a fresh empty list for collections and ``None``
        for scalar relationships.
    """
    metadata = {"sa": relationship(remote_class, uselist=use_list, order_by=order_by, lazy=lazy, **relationship_kwargs)}
    # The generated __init__ calls default_factory even for init=False fields. A scalar
    # (uselist=False) relationship holds a single object or None, never a list, so it must
    # not receive [] as its initial value.
    default_factory = list if use_list else (lambda: None)
    return dataclasses.field(metadata=metadata, init=init, default_factory=default_factory)


@dataclasses.dataclass(frozen=True)
class NoDefault:
    """
    Sentinel meaning "this dataclass field has no default".

    It compares equal to ``dataclasses.MISSING`` so that dict_to_dataclass ignores it.
    It is deliberately a distinct object rather than ``dataclasses.MISSING`` itself,
    because SQLAlchemy breaks on the real ``MISSING``. The sentinel is falsy, and
    calling it returns the sentinel itself.
    """

    def __bool__(self):
        # Behave like "nothing here" in boolean context.
        return False

    def __call__(self, *_args, **_kwargs):
        # Any call yields the very same sentinel. Presumably this lets it stand in
        # where a factory is expected; confirm against callers.
        return self

    def __repr__(self):
        class_name = type(self).__name__
        return class_name + '()'

    def __eq__(self, other):
        # Operand order matters. MISSING's own (identity) comparison runs first; if that
        # is NotImplemented, Python falls back to other.__eq__(MISSING).
        return dataclasses.MISSING == other


# The module exposes the single sentinel value under the class's name, not the class itself.
NoDefault = NoDefault()


def sql_field(type_: Union[Type[TypeEngine], TypeEngine, ForeignKey], foreign_key: ForeignKey = None,
              *, primary_key: bool = False, nullable: bool = True, unique: bool = False, index: bool = False,
              default: Any = None, server_default: Any = None, default_factory: Callable = None,
              insert_default: Callable[..., Awaitable[Any]] = None, autoincrement: bool = None,
              identifier: bool = False) -> dataclasses.Field:
    """
    Build a dataclass field that carries an SQLAlchemy ``Column`` in its ``"sa"`` metadata.

    :param type_: column type (or a ``ForeignKey`` used in its place); the first ``Column`` argument.
    :param foreign_key: optional ``ForeignKey``, passed to ``Column`` right after the type.
    :param primary_key: mark the column as primary key. For primary keys only ``primary_key`` is passed
        to ``Column`` (not ``nullable``/``unique``), and the dataclass field is treated as nullable.
    :param nullable: column nullability (ignored for primary keys).
    :param unique: column uniqueness (ignored for primary keys).
    :param index: create an index on the column when true.
    :param default: column default, passed to ``Column`` when not ``None``.
    :param server_default: server-side column default, passed to ``Column`` when not ``None``.
    :param default_factory: dataclass ``default_factory``. When given, it is the only dataclass default.
    :param insert_default: extension hook stored in metadata under ``"insert_default"`` (when truthy).
    :param autoincrement: passed to ``Column`` when not ``None``.
    :param identifier: stored in metadata under ``"identifier"``.
    :return: a ``dataclasses.Field``. Its dataclass default is chosen in this order:
        ``default_factory``; a truthy ``default``; ``NoValue`` if the field is nullable (or a primary key)
        or has an ``insert_default``/``server_default``; otherwise ``NoDefault``.
    """
    if primary_key:
        # Only affects the dataclass default chosen below: a primary-key Column never receives
        # ``nullable``. Presumably this lets PK values be omitted on construction; confirm.
        nullable = True

    column_args = [type_, foreign_key]
    column_kwargs = {"primary_key": primary_key} if primary_key else {"nullable": nullable, "unique": unique}

    if default is not None:
        column_kwargs['default'] = default
    if server_default is not None:
        column_kwargs['server_default'] = server_default
    if index:
        column_kwargs['index'] = index
    if autoincrement is not None:
        column_kwargs['autoincrement'] = autoincrement

    metadata = {"sa": Column(*column_args, **column_kwargs)}

    # extension attributes
    if insert_default:
        # TODO: check no "default" passed, think about using insert_default with None
        metadata['insert_default'] = insert_default
    metadata['identifier'] = identifier

    field_kwargs = {"metadata": metadata}

    # noinspection PyTypedDict
    if default_factory:
        # dataclasses.field() raises ValueError when both default and default_factory are
        # given, so a factory must exclude every other dataclass default.
        field_kwargs['default_factory'] = default_factory
    elif default:
        # NOTE(review): this is a truthiness test, unlike the ``is not None`` check used for the
        # Column default above. Falsy defaults (0, False, '', {}, []) reach the Column but not the
        # dataclass field. Kept as-is because mutable ``{}``/``[]`` would be rejected as dataclass
        # defaults; confirm this is intended.
        field_kwargs['default'] = default
    elif nullable or insert_default is not None or server_default is not None:
        field_kwargs['default'] = NoValue
    else:
        field_kwargs['default'] = NoDefault

    return dataclasses.field(**field_kwargs)
