#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import dataclasses
import hashlib
import itertools
import string
from typing import Optional, Any

from entity_read.sql.context import SqlContext
from .atoms import Column
from .atoms.node import SQL
from .atoms.order import Order
from entity_read.sql.atoms.abstract import Selectable


@dataclasses.dataclass
class LowerSelector:
    """One level of a (possibly nested) query that is rendered to SQL producing JSON.

    Either ``result_name_to_selectable`` (render each row as a JSON object) or ``solo_result``
    (render each row as that single expression) is filled, never both. Child selectors are kept in
    ``step_to_sub_selector`` and joined to this one through their ``local_column_to_remote``.
    """

    table_name: str
    # Optional schema; when set the table is addressed as "<schema>.<table>".
    schema_name: str = ''
    # Result JSON key -> expression rendering its value.
    result_name_to_selectable: dict[str, Selectable[Any]] = dataclasses.field(default_factory=dict)
    # Single expression rendered instead of a JSON object (mutually exclusive with the mapping above).
    solo_result: Selectable[Any] | None = None
    # Conditions, combined with AND when rendered.
    where: list[Selectable[bool]] = dataclasses.field(default_factory=list)
    # As a sub-selector: still joined, but not emitted as a key of the parent's JSON object.
    is_hidden: bool = False
    # As a sub-selector: rendered as a single object (defaulting to {}) rather than a list.
    is_scalar: bool = False
    order: list[Order] = dataclasses.field(default_factory=list)
    limit: Optional[int] = None
    offset: Optional[int] = None

    # Column of this selector's table -> column of the parent's table it must equal.
    local_column_to_remote: dict[str, str] = dataclasses.field(default_factory=dict)
    # Step name -> child selector.
    step_to_sub_selector: dict[str, 'LowerSelector'] = dataclasses.field(default_factory=dict)
    # Only used to build the name returned by get_name().
    my_step: str | None = None
    # When set, returned as-is by get_name().
    rename: str | None = None

    def __post_init__(self):
        if self.solo_result is not None and self.result_name_to_selectable:
            raise ValueError(
                f"only one of {self.solo_result=} and {self.result_name_to_selectable=} should be filled"
            )

    @property
    def full_table_name(self) -> str:
        """Table name, prefixed with the schema when one is set."""
        return f'{self.schema_name}.{self.table_name}' if self.schema_name else self.table_name

    def attach_selector(self: 'LowerSelector', child_name: str, child_selector: 'LowerSelector',
                        child_column_name_to_parent_column_name: dict[str, str]):
        """Register *child_selector* under the step *child_name* and record how it joins to this one.

        :param child_column_name_to_parent_column_name: child column -> parent column equalities used
            as the join condition; stored on the child (by reference) as ``local_column_to_remote``.
        :raises ValueError: if *child_name* is already attached.
        """
        if child_name in self.step_to_sub_selector:
            raise ValueError(f"attempt to attach selector with name: {child_name!r}, but it is already used")
        # Validate the child before mutating anything so a failed check leaves both selectors untouched.
        assert not child_selector.local_column_to_remote
        self.step_to_sub_selector[child_name] = child_selector
        child_selector.local_column_to_remote = child_column_name_to_parent_column_name

    def get_subquery_join_conditions(self, my_name: str, parent_name: str) -> list[SQL]:
        """Return ``my_name."local" = parent_name."remote"`` for every join column pair."""
        return [f'{my_name}."{local}" = {parent_name}."{remote}"' for local, remote in self.local_column_to_remote.items()]

    def get_local_relation_columns(self, my_name: str) -> dict[str, str]:
        """Map each local join column to its SQL reference qualified by *my_name*."""
        return {local: f'{my_name}."{local}"' for local in self.local_column_to_remote}

    def get_conditions(self, context: SqlContext) -> list[SQL]:
        """Render every ``where`` condition in *context*."""
        return [condition.to_sql(context) for condition in self.where]

    def get_order(self, context: SqlContext) -> str:
        """Render the ORDER BY clause, or ``''`` when there is no ordering."""
        order_sql = str.join(", ", [order.to_order_sql(context) for order in self.order])
        return "ORDER BY " + order_sql if order_sql else ''

    def get_limit_offset_sql(self) -> str:
        """Render LIMIT / OFFSET (on a new line), or ``''`` when neither is set."""
        parts: list[str] = []
        if self.limit is not None:
            parts.append(f'LIMIT {self.limit}')
        if self.offset is not None:
            parts.append(f'OFFSET {self.offset}')
        return '\n' + str.join(' ', parts) if parts else ''

    def to_sql(self, context: SqlContext, key_to_expression: dict | None = None) -> str:
        """Render this selector as a non-aggregated subquery aliased after ``context.self_table_name``.

        *key_to_expression* is accepted for interface compatibility and is not used.
        """
        return _generate_sql(context.self_table_name, context.self_table_name, self, context.key_to_param, agg=False)

    def get_name(self) -> str:
        """Return a descriptive name for this selector.

        ``rename`` wins when set. Otherwise the select / where / order parts are listed, wrapped as
        ``<my_step>.subquery(...)`` when ``my_step`` is set, or as ``Query(<table>[,<join>],...)``.
        Limit, offset and ``solo_result`` are not part of the name.
        """
        if self.rename:
            return self.rename

        parts = []
        select = ','.join([f'{k}={v.get_name()}' for k, v in self.result_name_to_selectable.items()])
        if select:
            parts.append(f'select={select}')

        where = ",".join([expression.get_name() for expression in self.where])
        if where:
            parts.append(f'where={where}')

        order = ",".join([expression.get_name() for expression in self.order])
        if order:
            parts.append(f'order={order}')

        if self.my_step:
            return f"{self.my_step}.subquery({','.join(parts)})"

        on = ','.join([f'{key}={value}' for key, value in self.local_column_to_remote.items()])
        # Table name and join mapping appear exactly once; an empty mapping adds no empty element.
        head = [self.full_table_name, on] if on else [self.full_table_name]
        return f"Query({','.join(head + parts)})"


def _ensure_label(alias: str, *, alias_max_len: int = 63, separator: str = "HIDDEN") -> str:
    if len(alias) < alias_max_len:
        return alias

    part_size = int((alias_max_len - len(separator)) / 2)
    return f'{alias[:part_size]}{separator}{alias[-part_size:]}'


# Column holding the jsonb_agg-ed list of a list selector.
JSON_LIST__NAME = 'json_list'
# Column holding one row's JSON value (object or solo result).
JSON_OBJECT__NAME = 'json_object'
# Max key/value PAIRS per jsonb_build_object call. Each pair takes two arguments and PostgreSQL
# functions accept at most 100 arguments (FUNC_MAX_ARGS), so 50 pairs is the ceiling.
MAX_FUNC_ARGS = 50


def _generate_sql(
        parent_alias: str, outer_alias: str, selector: LowerSelector, key_to_param: Optional[dict] = None,
        agg: bool = True
) -> SQL:
    """Build the PostgreSQL query rendering *selector* (and its sub-selectors) as JSON.

    The query has up to three layers:

    * when *selector* has sub-selectors, an inner derived table selecting every column the upper
      layer needs from the table, LEFT JOIN LATERAL-ing one subquery per sub-selector;
    * a SELECT producing one ``JSON_OBJECT__NAME`` column per row (a ``jsonb_build_object`` of the
      result columns, or ``solo_result`` alone) with WHERE / ORDER BY / LIMIT / OFFSET applied;
    * when *agg* is true and the selector is neither scalar nor ``solo_result``, a wrapper that
      ``jsonb_agg``-s those rows into a single ``JSON_LIST__NAME`` column.

    :param parent_alias: alias of the parent's table, referenced by the join conditions built from
        ``selector.local_column_to_remote``.
    :param outer_alias: alias under which the parent exposes this subquery; this level's aliases are
        derived from it.
    :param selector: what to select.
    :param key_to_param: handed to every ``SqlContext`` created here (presumably collects bound query
        parameters — confirm against ``SqlContext``).
    :param agg: whether a list selector is aggregated with ``jsonb_agg``.
    :return: the SQL text.
    """
    is_object = selector.solo_result is None
    if not is_object:
        assert all(sub.is_hidden for sub in selector.step_to_sub_selector.values()), \
            "cannot combine solo_result with non hidden sub_selector"

    current_table_with_subqueries_alias = _ensure_label(f'inner_{outer_alias}')
    current_table_alias = _ensure_label(f'{selector.table_name}_alias')

    # Inside the inner derived table expressions address the raw table; everything above it
    # (result columns, WHERE, ORDER BY) addresses the derived table.
    inner_context = SqlContext(current_table_alias, key_to_param)
    outer_context = SqlContext(current_table_with_subqueries_alias, key_to_param)

    # Columns the inner derived table must expose: everything filters, ordering and results depend
    # on, plus the parent-side columns that sub-selectors join against.
    required__name_to_sql: dict[str, str] = {}
    for node in itertools.chain(selector.where, selector.order, selector.result_name_to_selectable.values()):
        required__name_to_sql.update({column.get_name(): column.to_sql(inner_context) for column in node.requires()})
    for sub_selector in selector.step_to_sub_selector.values():
        columns = [Column(type_=None, key=column_name) for column_name in sub_selector.local_column_to_remote.values()]
        required__name_to_sql.update({column.get_name(): column.to_sql(inner_context) for column in columns})

    select__name_to_sql = {
        result_name: selectable.to_sql(outer_context)
        for result_name, selectable in selector.result_name_to_selectable.items()
    }

    joins = []
    # SQL column names carrying each sub-selector's JSON; they get an explicit "as" alias below.
    # Tracked separately because _ensure_label may shorten a step name, so it need not be a step key.
    child_column_names: set[str] = set()
    for step, child_selector in selector.step_to_sub_selector.items():
        child_alias = _ensure_label(f'{outer_alias}_{step}')
        child_column_name = _ensure_label(step)
        child_column_names.add(child_column_name)
        if not child_selector.is_hidden:
            # The JSON key is the full step name; only the SQL column label may be shortened.
            select__name_to_sql[step] = child_column_name
        if child_selector.solo_result is not None:
            # NOTE(review): a solo_result child is not aggregated, so this join relies on it yielding
            # at most one row per parent row — confirm intended.
            required__name_to_sql[child_column_name] = f"{child_alias}.{JSON_OBJECT__NAME}"
        elif child_selector.is_scalar:
            required__name_to_sql[child_column_name] = f"coalesce({child_alias}.{JSON_OBJECT__NAME}, '{{}}'::jsonb)"
        else:
            required__name_to_sql[child_column_name] = f"coalesce({child_alias}.{JSON_LIST__NAME}, '[]'::jsonb)"
        joins.append(
            f"\nLEFT JOIN LATERAL ({_generate_sql(current_table_alias, child_alias, child_selector, key_to_param)}) "
            f"as {child_alias} ON TRUE"
        )

    required__name_to_sql.update(selector.get_local_relation_columns(current_table_alias))
    required__name_to_sql = {
        k: (v + " as " + k if k in child_column_names else v)
        for k, v in required__name_to_sql.items()
    }
    where_conditions = " AND ".join(
        selector.get_subquery_join_conditions(current_table_with_subqueries_alias, parent_alias)
        + selector.get_conditions(outer_context))
    where_sql = "WHERE " + where_conditions if where_conditions else ""

    # Without sub-selectors the table itself is aliased as the "outer" alias; no derived table needed.
    from_ = f"FROM {selector.full_table_name} as {current_table_with_subqueries_alias}" if not joins else f"""
            FROM (
            SELECT {", ".join(required__name_to_sql.values())}
            FROM {selector.full_table_name} as {current_table_alias} {"".join(joins)}
        ) as {current_table_with_subqueries_alias}
    """

    # Rendered after the WHERE conditions and before ORDER BY, matching the original call order.
    select_sql = _prepare_json_object_sql(select__name_to_sql) if is_object else selector.solo_result.to_sql(outer_context)
    inner_query = f"""
        SELECT {select_sql} as {JSON_OBJECT__NAME}
        {from_}
        {where_sql} {selector.get_order(outer_context)} {selector.get_limit_offset_sql()}
    """
    if not agg or not is_object or selector.is_scalar:
        return inner_query
    # NOTE: list order relies on jsonb_agg consuming rows in the subquery's ORDER BY order, which
    # PostgreSQL documents as usually, not always, working.
    return f"""
    SELECT jsonb_agg({JSON_OBJECT__NAME}) as {JSON_LIST__NAME}
    FROM (
        {inner_query}
    ) as jsoned
    """


# Characters kept by escape(): printable ASCII letters, digits, punctuation and the plain space.
ALLOWED_CHARS = "".join((string.ascii_letters, string.digits, string.punctuation, " "))


def escape(key: str) -> str:
    """Make *key* safe to embed between single quotes in an SQL string literal.

    Every character outside ``ALLOWED_CHARS`` is dropped (non-ASCII letters, tabs and newlines
    disappear), then each remaining single quote is doubled.

    NOTE(review): keys differing only in dropped characters (e.g. Cyrillic names) collapse to the
    same result — confirm this is intended.
    """
    kept = "".join(ch for ch in key if ch in ALLOWED_CHARS)
    return kept.replace("'", "''")

def _prepare_json_object_sql(select__name_to_sql: dict[str, str]) -> SQL:
    """Build an SQL expression producing a jsonb object from *select__name_to_sql*.

    Keys become escaped string-literal JSON keys; values are embedded verbatim as SQL expressions.
    Pairs are split into ``jsonb_build_object`` calls of at most ``MAX_FUNC_ARGS`` pairs each (to stay
    under PostgreSQL's function-argument limit) joined with ``||``. An empty mapping yields
    ``jsonb_build_object()``.
    """
    pairs = list(select__name_to_sql.items())
    if not pairs:
        return "jsonb_build_object()"
    chunks = (pairs[start:start + MAX_FUNC_ARGS] for start in range(0, len(pairs), MAX_FUNC_ARGS))
    return " || ".join(
        "jsonb_build_object(" + ", ".join(f"'{escape(key)}', {expr}" for key, expr in chunk) + ")"
        for chunk in chunks
    )

# <- generate_sql
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# -> Utils


def wrap_json_sql_as_list(sql: str) -> str:
    """Select the ``JSON_LIST__NAME`` column of *sql*, substituting an empty JSON array for NULL."""
    template = "SELECT coalesce({column}, '[]'::jsonb) FROM ({query}) as _main_"
    return template.format(column=JSON_LIST__NAME, query=sql)


def wrap_json_sql_as_object(sql: str) -> str:
    """Select the ``JSON_OBJECT__NAME`` column of *sql*, substituting an empty JSON object for NULL."""
    # "{{}}" is str.format's escape for the literal "{}".
    template = "SELECT coalesce({column}, '{{}}'::jsonb) FROM ({query}) as _main_"
    return template.format(column=JSON_OBJECT__NAME, query=sql)


def wrap_sql_as_json_list(sql: str) -> str:
    """Aggregate the ``JSON_OBJECT__NAME`` column of every row of *sql* into one JSON array (``[]`` if none)."""
    template = "SELECT coalesce(jsonb_agg({column}), '[]'::jsonb) FROM ({query}) as _main_"
    return template.format(column=JSON_OBJECT__NAME, query=sql)


def generate_sql_with_args_from_selector(selector: LowerSelector) -> tuple[str, dict]:
    """Render *selector* as a complete query returning a single JSON value.

    A ``solo_result`` selector yields an array of its values, a scalar selector an object (``{}``
    when empty), and any other selector an array of objects (``[]`` when empty).

    :return: ``(sql, params)`` where ``params`` is the dict handed to ``_generate_sql`` as
        ``key_to_param`` (presumably filled with the query's bound parameters — confirm against
        ``SqlContext``).
    """
    params: dict = {}
    body = _generate_sql("", "", selector, params)
    if selector.solo_result is None:
        wrapper = wrap_json_sql_as_object if selector.is_scalar else wrap_json_sql_as_list
    else:
        # Rows carry one solo value each; collect them into a JSON array.
        wrapper = wrap_sql_as_json_list
    return wrapper(body), params
