import json
from itertools import groupby
from typing import Union, List, Any

from sqlalchemy import Column, JSON
from sqlalchemy.orm import registry
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import FromClause
from sqlalchemy.sql import sqltypes
from sqlalchemy.dialects import postgresql

from ..utils.first import first
from ..utils.to_list import to_list


# Module-level mapper registry shared by this package.
# ``initial_values`` is an attribute attached ad hoc to the registry object;
# register_initial_values() fills it as ``{type(obj): [obj, ...]}``.
# How it is consumed is not visible in this module (presumably seed rows to
# insert after table creation — TODO confirm against callers).
sqlalchemy_mapper_registry = registry()
sqlalchemy_mapper_registry.initial_values = {}


def register_initial_values(mapper_registry: registry, instance: Union[List[Any], Any], *instances: Any):
    """Record objects in ``mapper_registry.initial_values``, grouped by type.

    Each object is appended to the list stored under ``type(obj)`` in the
    registry's ``initial_values`` dict, which is created on first use. Objects
    of the same type keep the order in which they were passed, even when
    objects of different types are interleaved.

    Args:
        mapper_registry: Registry object to attach the values to.
        instance: One object or a list of objects. It is passed through
            ``to_list`` (assumed to wrap a non-list into a one-element list —
            project helper, not visible here).
        *instances: Further objects to register after ``instance``.
    """
    # ``initial_values`` is an ad-hoc attribute; create it lazily so any
    # registry (not only the module-level one) can be used.
    initial_values = getattr(mapper_registry, 'initial_values', {})
    mapper_registry.initial_values = initial_values

    # Group in a single pass. ``itertools.groupby`` only merges *consecutive*
    # items with equal keys, so collecting its groups into a dict kept only the
    # last run of each type and silently dropped the earlier ones.
    for obj in to_list(instance) + list(instances):
        initial_values.setdefault(type(obj), []).append(obj)


class Values(FromClause):
    """A literal ``VALUES`` list that can be selected from like a table.

    SQL text is produced by the ``compile_values`` hook registered for this
    class. The exposed columns are named ``column1``, ``column2``, ... — one
    per entry of ``columns`` (presumably mirroring PostgreSQL's default names
    for the columns of a bare VALUES list).

    Args:
        args: Sequence of row tuples to render.
        columns: Column per row position; its ``type`` is what the compiler
            uses when rendering each literal.
    """

    def __init__(self, args, columns: List[Column] = None):
        # These attribute names are read by the compiler hook; keep them as-is.
        self.list = args
        self.columns_list = columns
        # Width of the first row (``first`` is a project helper).
        leading_row = first(args)
        self.columns_amount = len(leading_row)

    def _populate_column_collection(self):
        # One column per declared column, built from its positional name only
        # (no type is passed): column1, column2, ...
        names = [f"column{position}" for position in range(1, len(self.columns_list) + 1)]
        self._columns._populate_separate_keys((name, Column(name)) for name in names)


def _pg_array_literal(items) -> str:
    """Build PostgreSQL array-literal text (``{...}``) from a Python list.

    ``None`` becomes an unquoted ``NULL``. Nested lists recurse into
    multi-dimensional literals. Numbers and bools are written with ``str()``.
    Everything else is double-quoted, with ``"`` and ``\\`` backslash-escaped
    (dicts are JSON-encoded first), so commas, braces, empty strings and the
    word NULL survive as data.
    """
    parts = []
    for item in items:
        if item is None:
            parts.append('NULL')
        elif isinstance(item, list):
            parts.append(_pg_array_literal(item))
        elif isinstance(item, (bool, int, float)):
            parts.append(str(item))
        else:
            text = json.dumps(item) if isinstance(item, dict) else str(item)
            escaped = text.replace('\\', '\\\\').replace('"', '\\"')
            parts.append(f'"{escaped}"')
    return '{' + ', '.join(parts) + '}'


def _render_literal_value(compiler, element: Any, column: Column):
    """Render one Python value as inline SQL literal text for a VALUES row.

    Args:
        compiler: The active SQL compiler; its ``render_literal_value`` does
            the final quoting of the literal.
        element: ``None``, ``bool``/``int``/``float``/``str``, ``dict``
            (serialized with ``json.dumps``) or ``list`` (JSON text for JSON
            columns, otherwise a PostgreSQL array literal).
        column: Column for this row position; ``column.type`` drives rendering.

    Returns:
        The SQL text of the literal.

    Raises:
        NotImplementedError: If ``element`` is of any other type.
    """
    col_type = column.type
    if element is None:
        # TODO: use dialect based version
        result = f'NULL::{str(col_type)}'
    elif isinstance(element, (bool, int, float, str)):
        result = compiler.render_literal_value(element, col_type)
    elif isinstance(element, dict):
        result = compiler.render_literal_value(json.dumps(element), sqltypes.String())
    elif isinstance(element, list):
        # ``column.type`` holds a TypeEngine *instance*, so an isinstance()
        # check is required. Comparing the instance against the classes with
        # ``in`` never matched, which sent JSON columns down the array branch.
        if isinstance(col_type, (postgresql.JSON, postgresql.JSONB, JSON)):
            # Real JSON (quoted strings, true/false/null), like the dict branch.
            text = json.dumps(element)
        else:
            text = _pg_array_literal(element)
        result = compiler.render_literal_value(text, sqltypes.String())
    else:
        raise NotImplementedError(f'unexpected type in VALUES: {type(element)}')

    return result


@compiles(Values)
def compile_values(element: Values, compiler, asfrom=False, **kw):
    """Compile a ``Values`` construct into ``VALUES (...), (...)`` SQL text.

    Every value is rendered inline by ``_render_literal_value``, using the
    column at the same position in ``element.columns_list``. When compiled as
    a FROM item (``asfrom=True``), the whole list is wrapped in parentheses.
    """
    columns = element.columns_list
    rows_sql = []
    for row in element.list:
        rendered = [
            _render_literal_value(compiler, value, columns[position])
            for position, value in enumerate(row)
        ]
        rows_sql.append("(" + ", ".join(rendered) + ")")
    sql = "VALUES " + ", ".join(rows_sql)
    return f"({sql})" if asfrom else sql
