#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>, Alexander Medvedev <a.medvedev@abm-jsc.ru>

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import TypeAlias, Any, Iterable

from dict_caster.extras import to_list
from sqlalchemy import select, delete, Select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy_tools.database_connector.abstract_database_connector import AbstractDatabaseConnector

from .tools.get_timestamp import get_current_timestamp_in_msec
from .entities.file_metadata import FileMetadata

# Shapes a WHERE-clause filter may take: a list of SQL expressions,
# a single SQL expression, or a plain Python bool.
Filter: TypeAlias = (
    list[BinaryExpression]
    | BinaryExpression
    | bool
)


class FileMetadataDoesNotExists(KeyError):
    """Raised when no file metadata row matches a requested storage file name."""


class FileDatabaseFacade:
    """Async data-access facade for ``FileMetadata`` rows.

    Query methods run inside a caller-supplied ``AsyncSession`` and never commit
    on their own; ``ensure_session`` provides a session that is committed
    automatically when the caller does not already have one.

    The regular getters (``get_info`` and ``get_file_metadatas_by_*``) only see
    rows whose ``outdated_at`` is NULL, while ``get_outdated_file_metadatas``
    returns the rows whose ``outdated_at`` is set.
    """

    @dataclass
    class Context:
        """Dependencies of ``FileDatabaseFacade``."""

        # Source of new sessions for ``ensure_session``.
        database_connector: AbstractDatabaseConnector

    def __init__(self, context: Context) -> None:
        self._context = context

    @asynccontextmanager
    async def ensure_session(self, unsafe_session: AsyncSession | None = None) -> AsyncIterator[AsyncSession]:
        """Provide a session for the duration of an ``async with`` block.

        :param unsafe_session: an already-open session owned by the caller. When
            given, it is yielded unchanged and is NOT committed here; the caller
            stays responsible for its transaction.
        :return: an async context manager yielding an ``AsyncSession``. When no
            session is passed in, a new one is opened through the configured
            connector and committed after the ``async with`` body finishes
            without raising.
        """
        if unsafe_session is not None:
            yield unsafe_session
        else:
            async with self._context.database_connector.get_session() as session:
                yield session
                # Only reached when the body did not raise: an exception is thrown
                # into the generator at ``yield`` and skips the commit.
                await session.commit()

    async def get_info(self,
                       session: AsyncSession,
                       storage_file_name: str,
                       raise_if_not_exist: bool = True) -> FileMetadata | None:
        """Return the non-outdated metadata stored under ``storage_file_name``.

        :param session: session to run the query in.
        :param storage_file_name: exact storage name to look up.
        :param raise_if_not_exist: raise instead of returning ``None`` when nothing matches.
        :return: the first matching row (ordered by ``uploaded_at``, with
            ``thumbnail`` eagerly loaded), or ``None``.
        :raises FileMetadataDoesNotExists: if nothing matches and
            ``raise_if_not_exist`` is true.
        """
        query = self._prepare_file_metadata_select_query(storage_file_name)
        file_metadata = (await session.execute(query)).scalar()

        if file_metadata is None and raise_if_not_exist:
            raise FileMetadataDoesNotExists(f'Not found file with storage_file_name: {storage_file_name!r}')

        return file_metadata

    async def get_file_metadatas_by_file_names(self,
                                               session: AsyncSession,
                                               storage_file_names: Iterable[str],
                                               ) -> list[FileMetadata]:
        """Return non-outdated metadata for the given storage names, ordered by ``uploaded_at``.

        An empty ``storage_file_names`` yields an empty list (it does not disable
        the name filter).
        """
        query = self._prepare_file_metadata_select_query(storage_file_names)
        return list((await session.execute(query)).scalars())

    async def get_file_metadatas_by_timerange(self,
                                              session: AsyncSession,
                                              from_datetime: float | None = None,
                                              to_datetime: float | None = None,
                                              limit: int | None = None,
                                              offset: int | None = None,
                                              ) -> list[FileMetadata]:
        """Return non-outdated metadata uploaded strictly between the given bounds.

        :param from_datetime: exclusive lower bound on ``uploaded_at``; ``None`` means unbounded.
        :param to_datetime: exclusive upper bound on ``uploaded_at``; ``None`` means unbounded.
        :param limit: maximum number of rows; ``None`` means no limit.
        :param offset: number of rows to skip; ``None`` means no offset.
        :return: matching rows ordered by ``uploaded_at``.
        """
        query = self._prepare_file_metadata_select_query(
            from_datetime=from_datetime,
            to_datetime=to_datetime,
            limit=limit,
            offset=offset,
        )
        return list((await session.execute(query)).scalars())

    @staticmethod
    async def get_outdated_file_metadatas(session: AsyncSession,
                                          limit: int | None = None,
                                          offset: int | None = None,
                                          ) -> list[FileMetadata]:
        """Return rows whose ``outdated_at`` is set, ordered by ``uploaded_at``.

        :param limit: maximum number of rows; ``None`` means no limit.
        :param offset: number of rows to skip; ``None`` means no offset.
        """
        query = select(
            FileMetadata
        ).options(
            joinedload(FileMetadata.thumbnail)
        ).where(
            # Must be a SQL ``IS NOT NULL``: a Python ``is not None`` on the column
            # attribute is always True and would select every row.
            FileMetadata.outdated_at.is_not(None)
        ).order_by(
            FileMetadata.uploaded_at
        ).limit(limit).offset(offset)
        return list((await session.execute(query)).scalars())

    @staticmethod
    async def set_outdated_at(session: AsyncSession, storage_file_names: list[str]) -> None:
        """Stamp ``outdated_at`` with ``get_current_timestamp_in_msec()`` for the given names.

        Rows that are already outdated are re-stamped as well; there is no filter
        on the current ``outdated_at`` value.
        """
        query = update(
            FileMetadata
        ).where(
            FileMetadata.storage_file_name.in_(storage_file_names)
        ).values(
            {FileMetadata.outdated_at: get_current_timestamp_in_msec()}
        )
        await session.execute(query)

    @staticmethod
    async def add(session: AsyncSession, file_metadata: FileMetadata) -> None:
        """Add ``file_metadata`` to the session and flush it (without committing)."""
        session.add(file_metadata)
        await session.flush()

    @staticmethod
    async def delete(session: AsyncSession, storage_file_name: str) -> bool:
        """Delete the row(s) named ``storage_file_name``; return True if any row was deleted."""
        query = delete(FileMetadata).where(FileMetadata.storage_file_name == storage_file_name)
        return bool((await session.execute(query)).rowcount)

    @staticmethod
    async def delete_file_metadatas(session: AsyncSession, storage_file_names: list[str] | str) -> list[FileMetadata]:
        """Delete rows by storage name and return the deleted rows.

        :param storage_file_names: a single name or a list of names.
        :return: the deleted ``FileMetadata`` rows, as reported by ``RETURNING``.
        """
        query = delete(
            FileMetadata
        ).where(
            FileMetadata.storage_file_name.in_(FileDatabaseFacade._as_file_name_list(storage_file_names))
        ).execution_options(
            # Do not try to reconcile objects already loaded in the session.
            synchronize_session=False
        ).returning(
            FileMetadata
        )
        return list((await session.execute(query)).scalars())

    def _prepare_file_metadata_select_query(self,
                                            storage_file_names: str | Iterable[str] | None = None,
                                            from_datetime: float | None = None,
                                            to_datetime: float | None = None,
                                            limit: int | None = None,
                                            offset: int | None = None,
                                            ) -> Select[Any]:
        """Build the SELECT shared by the regular getters.

        The query eagerly loads ``thumbnail``, applies the filters from
        ``_prepare_file_metadata_filters``, orders by ``uploaded_at``, and applies
        ``limit``/``offset`` (``None`` means not applied).
        """
        filters = self._prepare_file_metadata_filters(storage_file_names, from_datetime, to_datetime)

        return select(
            FileMetadata
        ).options(
            joinedload(FileMetadata.thumbnail)
        ).where(
            *to_list(filters)
        ).order_by(
            FileMetadata.uploaded_at
        ).limit(limit).offset(offset)

    @staticmethod
    def _prepare_file_metadata_filters(storage_file_names: str | Iterable[str] | None = None,
                                       from_datetime: float | None = None,
                                       to_datetime: float | None = None) -> Filter:
        """Build WHERE conditions for the regular getters.

        Each argument that is not ``None`` adds a condition; an empty name
        collection therefore matches nothing instead of disabling the filter,
        and a ``0`` bound is honoured. Rows with ``outdated_at`` set are always
        excluded.
        """
        filters = []

        if storage_file_names is not None:
            filters.append(
                FileMetadata.storage_file_name.in_(FileDatabaseFacade._as_file_name_list(storage_file_names))
            )

        if from_datetime is not None:
            filters.append(FileMetadata.uploaded_at > from_datetime)

        if to_datetime is not None:
            filters.append(FileMetadata.uploaded_at < to_datetime)

        # Must be a SQL ``IS NULL``: a Python ``is None`` on the column attribute
        # is always False and would turn into ``WHERE false``.
        filters.append(FileMetadata.outdated_at.is_(None))

        return filters

    @staticmethod
    def _as_file_name_list(storage_file_names: str | Iterable[str]) -> list[str]:
        """Wrap a single name in a list, or materialize an iterable of names."""
        # A bare ``str`` is itself iterable and must not be split into characters;
        # SQLAlchemy's ``in_()`` does not accept a plain string either.
        if isinstance(storage_file_names, str):
            return [storage_file_names]
        return list(storage_file_names)
