#  Copyright (C) 2025
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import dataclasses
import functools
import logging
import re
import types
from decimal import Decimal
from enum import Enum
from types import UnionType
from typing import Callable, Any, get_args, get_origin, get_type_hints, Union, Literal

import frozendict
from dict_caster.extras import first
from http_tools import WrappedAnswerBody
from init_helpers import try_extract_type_notes, DataclassProtocol, raise_if, Jsonable, AnyType
from init_helpers.dict_to_dataclass import NoValue

from .any_of_schema import AnyOfSchema
from .all_of_schema import AllOfSchema
from .array_schema import ArraySchema
from .base_schema import BaseSchema
from .boolean_schema import BooleanSchema
from .empty_schema import EmptySchema
from .integer_schema import IntegerSchema
from .null_schema import NullSchema
from .number_schema import NumberSchema
from .object_schema import ObjectSchema
from .schema_format import StringSchemaFormat
from .string_schema import StringSchema
from .type_schema import TypeSchema
from ..example import BaseExample, Example
from ...server.endpoint_body import EndpointBody

# if TYPE_CHECKING:
#     from spec_schema import SpecSchema

logger = logging.getLogger(__name__)
#
# @functools.cache
# def build_spec_schema(schema: AnyType) -> 'SpecResource | NamedSpecSchema':
#     from spec_schema import SpecSchema
#     from named_spec_schema import NamedSpecSchema
#
#     if isinstance(schema, SpecSchema):
#         return schema
#
#     clean_type, notes = try_extract_type_notes(schema)
#
#     if examples := tuple(note for note in notes if isinstance(note, BaseExample)):
#         example = examples[0].value
#
#     if dataclasses.is_dataclass(schema):
#         return NamedSpecSchema(type_=schema, name=f"schema_{clean_type.__name__}")
#
#     return SpecSchema(schema)


class SchemaBuilder:
    """Translate Python type annotations into ``BaseSchema`` objects.

    Handled inputs: ``Any``; generic aliases (``Union`` / ``X | Y``, ``dict`` / ``frozendict``,
    ``list`` / ``set`` / ``frozenset``, homogeneous ``tuple``, ``Literal``); the bare,
    un-parametrized container types; dataclasses; and the primitive types listed in
    ``primitive_type_to_factory`` (including ``Enum`` subclasses of those primitives).
    """

    @classmethod
    def build_schema(
            cls,
            schema: AnyType, default: Any = NoValue,
            description: str | None = None,
            examples: tuple[BaseExample, ...] | None = None,
            enum: tuple[Jsonable, ...] | None = None,
    ) -> BaseSchema:
        """Build a schema for ``schema``, merging in the notes attached to the type.

        Notes are obtained via ``try_extract_type_notes``:

        * ``Example`` notes are appended after the explicit ``examples``;
        * ``str`` notes are description candidates; an explicit ``description`` wins over them;
        * the first ``re.Pattern`` note supplies the string ``pattern``.

        :param schema: the (possibly annotated) type to describe.
        :param default: default value to record; ``NoValue`` means "no default".
        :param description: explicit description, takes precedence over string notes.
        :param examples: explicit examples.
        :param enum: explicit tuple of allowed values.
        :return: the built schema.
        :raises TypeError: if the type (or one of its type arguments) is not supported.
        """
        logger.debug('_build_schema: %s', schema)
        schema, notes = try_extract_type_notes(schema)
        examples = (examples or tuple()) + tuple(note for note in notes if isinstance(note, Example))
        # `description` is a single string: wrap it before concatenating with the string notes
        # (concatenating a str with a tuple raises TypeError).
        explicit_descriptions = (description,) if description else ()
        descriptions = explicit_descriptions + tuple(note for note in notes if isinstance(note, str))
        pattern = first((note.pattern for note in notes if isinstance(note, re.Pattern)), none_if_empty=True)
        description = descriptions[0] if descriptions else None

        kwargs = {
            "description": description, "examples": examples, "default": default, "enum": enum, "pattern": pattern}

        if schema is Any:
            # `Any` is described as a union of all JSON-compatible value types.
            schema = dict | list | str | float | int | bool | None

        # TODO: think about it, should we place "required" inside schema properties or not?
        if origin_type := get_origin(schema):
            return cls._get_generic_description(origin_type, get_args(schema), **kwargs)
        # `frozendict` is the imported module; the type itself is `frozendict.frozendict`.
        if schema in (dict, frozendict.frozendict, list, set, frozenset, tuple):
            # Bare (un-parametrized) container: its item schema stays unspecified.
            return cls._get_generic_description(schema, **kwargs)
        if dataclasses.is_dataclass(schema):
            return cls._get_dataclass_schema(schema, **kwargs)
        return cls._get_primitive_description(schema, **kwargs)

    @classmethod
    def _get_primitive_description(
            cls, type_: type,
            description: str | None = None,
            examples: tuple[BaseExample, ...] | None = None,
            default: Any = NoValue,
            enum: tuple[Jsonable, ...] | None = None,
            pattern: str | None = None,
    ) -> TypeSchema:
        """Build a schema for a primitive type or an ``Enum`` that subclasses one.

        For an ``Enum`` subclass with no explicit ``enum``, the member values become the enum and
        the class name becomes the schema ``key``. ``pattern`` is only forwarded when set.

        :raises TypeError: if ``type_`` is not a class, or is not a subclass of any key of
            ``primitive_type_to_factory``.
        """
        kwargs = {'description': description, 'examples': examples, 'default': default, 'enum': enum}
        # NOTE: issubclass() raises TypeError when `type_` is not a class (e.g. a string annotation);
        # _get_dataclass_schema relies on that TypeError to retry with resolved type hints.
        if not enum and issubclass(type_, Enum):
            kwargs['enum'] = tuple(e.value for e in type_)
            kwargs['key'] = type_.__name__
        if pattern:
            kwargs['pattern'] = pattern

        # Dict order matters: the first matching base wins (bool must be checked before int).
        for primitive_type, factory in cls.primitive_type_to_factory.items():
            if issubclass(type_, primitive_type):
                return factory(**kwargs)
        raise TypeError(f"Unknown type: {type_}")

    @classmethod
    def _get_dataclass_schema(
            cls,
            schema: type[DataclassProtocol],
            description: str | None = None,
            examples: tuple[BaseExample, ...] | None = None,
            default: Any = NoValue,
            enum: tuple[Jsonable, ...] | None = None,
            pattern: str | None = None,
    ) -> ObjectSchema:
        """Build an object schema from the fields of a dataclass.

        Fields declared with ``repr=False`` are skipped. A field's ``default`` (not
        ``default_factory``) is recorded on its schema; it is wrapped in ``AllOfSchema`` when the
        field schema already has a default or is a keyed schema. ``enum`` and ``pattern`` are
        accepted for signature compatibility but are not used.

        :return: an ``ObjectSchema``, or an ``AllOfSchema`` wrapping it when a description,
            examples or a default is supplied for the dataclass itself.
        """
        type_hints: dict[str, Any] | None = None  # resolved lazily, only when a raw annotation fails
        key_to_schema = {}
        for field_ in dataclasses.fields(schema):
            if not field_.repr:
                continue

            try:
                field_schema = cls.build_schema(field_.type)
            except TypeError:
                # field_.type may be an unresolved annotation (e.g. a string); retry with the
                # hint resolved by typing.get_type_hints.
                if type_hints is None:
                    type_hints = get_type_hints(schema)
                field_schema = cls.build_schema(type_hints[field_.name])
            if field_.default is not dataclasses.MISSING:
                if not field_schema.has_default and not field_schema.key:
                    field_schema = dataclasses.replace(field_schema, default=field_.default)
                else:
                    # Do not alter a keyed (shared) schema or override its own default.
                    field_schema = AllOfSchema([field_schema, EmptySchema(default=field_.default)])
            key_to_schema[field_.name] = field_schema
        key = None if issubclass(schema, WrappedAnswerBody | EndpointBody) else schema.__name__
        result = ObjectSchema(item_key_to_schema=key_to_schema, key=key)
        if description or examples or default is not NoValue:
            result = AllOfSchema([result, EmptySchema(description=description, examples=examples, default=default)])
        return result

    @classmethod
    def _get_generic_description(
            cls, origin_type: type | UnionType, type_args: tuple[type | UnionType, ...] | None = None,
            description: str | None = None,
            examples: tuple[BaseExample, ...] | None = None,
            default: Any = NoValue,
            enum: tuple[Jsonable, ...] | None = None,
            pattern: str | None = None,
    ) -> BaseSchema:
        """Build a schema for a generic alias (``origin_type[type_args]``) or a bare container.

        :param origin_type: result of ``typing.get_origin`` or a bare container class.
        :param type_args: result of ``typing.get_args``; ``None`` for a bare container.
        :raises TypeError: on an unsupported origin or malformed type arguments.
        """
        kwargs = {'description': description, 'examples': examples, 'default': default}
        if origin_type in (Union, types.UnionType):
            raise_if(not type_args, TypeError(f'bad {origin_type.__name__} {type_args=}'))
            return AnyOfSchema(tuple(cls.build_schema(subtype) for subtype in type_args), **kwargs)
        elif origin_type in (dict, frozendict.frozendict):
            raise_if(type_args and len(type_args) != 2, TypeError(f'bad {origin_type.__name__} {type_args=}'))
            return ObjectSchema(additional_items_schema=cls.build_schema(type_args[1]) if type_args else None, **kwargs)
        elif origin_type in (list, set, frozenset):
            raise_if(type_args and len(type_args) != 1, TypeError(f'bad {origin_type.__name__} {type_args=}'))
            return ArraySchema(items_schema=cls.build_schema(type_args[0]) if type_args else None,
                               is_uniqueness_required=origin_type != list, **kwargs)
        elif origin_type is tuple:
            # tuple[X, ...] and homogeneous tuple[X, X] both map to an array of X; `...` is the
            # Ellipsis singleton. A bare `tuple` arrives with type_args=None.
            # TODO: limit items amount if ellipsis is missing
            items_types = {t for t in (type_args or ()) if t is not Ellipsis}
            raise_if(len(items_types) > 1, TypeError(f'Inconsistent tuple got: {origin_type.__name__} {type_args=}'))
            return ArraySchema(items_schema=cls.build_schema(first(items_types)) if items_types else None, **kwargs)
        elif origin_type is Literal:
            raise_if(not type_args, TypeError(f'bad {origin_type.__name__} {type_args=}'))
            kwargs['enum'] = type_args
            arg_types = {type(arg) for arg in type_args}
            if len(arg_types) > 1:
                return AnyOfSchema(tuple(cls.build_schema(subtype) for subtype in arg_types), **kwargs)
            else:
                return cls.build_schema(first(arg_types), **kwargs)

        raise TypeError(f"Unknown origin_type: {origin_type}")

    # Primitive type -> schema factory; matched with issubclass() in insertion order.
    primitive_type_to_factory: dict[type, Callable[..., TypeSchema]] = {
        bool: BooleanSchema,  # MUST be before int
        int: IntegerSchema,
        float: NumberSchema,
        bytes: functools.partial(StringSchema, format_=StringSchemaFormat.binary),
        bytearray: functools.partial(StringSchema, format_=StringSchemaFormat.binary),
        str: StringSchema,
        Decimal: functools.partial(StringSchema, pattern=r'^\d*\.?\d*$'),
        type(None): NullSchema,
    }
