#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import dataclasses
import functools
import inspect
from dataclasses import dataclass, Field
from typing import ClassVar, Callable, Self, Any

from init_helpers.custom_json import ReprInDumps
from init_helpers.dict_to_dataclass import get_dataclass_field_name_to_field

# Dict key under which TypedInDumps.__repr_in_dumps__ stores the instance's
# class name (prefixed with "<__namespace__>::" when the class defines one),
# so the dumped dict records which concrete class produced it.
TYPE_KEY = "_t"


@dataclass(frozen=True)
class TypedInDumps(ReprInDumps):
    """Frozen-dataclass base whose dump representation carries a type tag.

    ``__repr_in_dumps__`` returns ``{TYPE_KEY: <prefixed class name>, **fields}``
    where ``fields`` holds only the ``repr=True`` dataclass fields whose value
    differs from the field default (see ``as_dict``).  The prefix is
    ``__namespace__ + "::"`` when the class defines ``__namespace__``.  The
    ``get_*subclass*`` class methods resolve names back to classes across the
    subclass tree.
    """

    # Optional override for the dumped representation: when not None,
    # __repr_in_dumps__ returns ``self.shortcut()`` instead of the tagged dict.
    # NOTE(review): it is accessed through ``self``, so a plain function or
    # lambda assigned in a subclass body gets bound and receives ``self``;
    # presumably subclasses assign a staticmethod or otherwise self-aware
    # callable — confirm.
    shortcut: ClassVar[Callable[[], str] | None] = None

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Invalidate the cached name lookup whenever a subclass is created.

        ``get_name_to_subclass`` is memoised; without this, a subclass that is
        defined (imported) after the first lookup could never be resolved.
        """
        super().__init_subclass__(**kwargs)
        TypedInDumps.get_name_to_subclass.cache_clear()

    @classmethod
    def get_subclasses(cls, recursive: bool = True, include_self: bool = True) -> set[type[Self]]:
        """Collect subclasses of ``cls``.

        :param recursive: when True, also include indirect subclasses (each
            direct child's own ``get_subclasses`` result is merged in).
        :param include_self: when True, ``cls`` itself is part of the result.
        :return: an unordered set of classes.
        """
        direct = set(cls.__subclasses__())
        result = set(direct)
        if recursive:
            for child in direct:
                result |= child.get_subclasses()
        if include_self:
            result.add(cls)
        return result

    @classmethod
    @functools.cache
    def get_name_to_subclass(cls) -> dict[str, type[Self]]:
        """Map ``__name__`` -> class for ``cls`` and all of its subclasses.

        Memoised per ``cls``; the cache is cleared by ``__init_subclass__``
        whenever a new subclass appears.  Classes that share a ``__name__``
        collide and only one of them is kept (which one is arbitrary).  The
        returned dict is shared between callers, so do not mutate it.
        """
        return {subclass.__name__: subclass for subclass in cls.get_subclasses()}

    @classmethod
    def get_prefixed_name_to_non_abstract_subclass(cls) -> dict[str, type[Self]]:
        """Map ``<prefix><__name__>`` -> class for the concrete classes in the tree.

        Built from ``get_subclasses`` directly rather than from the
        ``__name__``-keyed map, so that same-named classes living in different
        ``__namespace__``s are all kept under their distinct prefixed names.
        Abstract classes (per ``inspect.isabstract``) are skipped.
        """
        return {
            f"{subclass.get_prefix()}{subclass.__name__}": subclass
            for subclass in cls.get_subclasses()
            if not inspect.isabstract(subclass)
        }

    @classmethod
    def get_subclass_by_name(cls, subclass_name: str) -> type[Self]:
        """Return ``cls`` or the subclass whose ``__name__`` is ``subclass_name``.

        :raises KeyError: if no class in the tree has that name.
        """
        return cls.get_name_to_subclass()[subclass_name]

    def __repr_in_dumps__(self):
        """Return the value used for this instance when dumped.

        This is ``self.shortcut()`` when a shortcut is configured. Otherwise it
        is a dict with the type tag first, followed by ``as_dict()``.
        """
        if self.shortcut is not None:
            return self.shortcut()
        # NOTE(review): a dataclass field literally named TYPE_KEY would
        # overwrite the tag here, since the right-hand dict wins the merge.
        return {TYPE_KEY: self.get_prefix() + type(self).__name__} | self.as_dict()

    @classmethod
    def get_prefix(cls) -> str:
        """Return ``__namespace__ + "::"`` if the class defines one, else ``''``."""
        return cls.__namespace__ + "::" if hasattr(cls, "__namespace__") else ''

    def as_dict(self) -> dict[str, Any]:
        """Return the ``repr=True`` fields whose values differ from their defaults.

        Fields equal to their ``default``, or to a freshly built
        ``default_factory()`` value, are omitted to keep dumps and reprs
        compact.  Fields come from ``get_dataclass_field_name_to_field``
        called with ``with_init_vars=False``.
        """
        result = {}
        dataclass_fields = get_dataclass_field_name_to_field(type(self), with_init_vars=False)
        for key in dataclass_fields:
            descriptor: Field = dataclass_fields[key]
            if not descriptor.repr:
                continue
            val = getattr(self, key)
            # Only compare against a real default: ``val == MISSING`` is
            # meaningless and can raise or misfire for values with a custom
            # ``__eq__`` (e.g. array-like objects).
            if descriptor.default is not dataclasses.MISSING and val == descriptor.default:
                continue
            if descriptor.default_factory is not dataclasses.MISSING and val == descriptor.default_factory():
                continue
            result[key] = val
        return result

    def __repr__(self):
        """Render as ``ClassName(field=value, ...)`` using only ``as_dict()`` fields."""
        return f'{type(self).__name__}({", ".join([f"{k}={v!r}" for k, v in self.as_dict().items()])})'

    __str__ = __repr__
