#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>, Alexander Medvedev <a.medvedev@abm-jsc.ru>
import logging
import typing
from contextlib import asynccontextmanager
from dataclasses import dataclass

from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_tools.database_connector.abstract_database_connector import AbstractDatabaseConnector

# Module-level logger for this module.
logger: logging.Logger = logging.getLogger(__name__)
# Generic type parameter used in the query-helper annotations of this module.
T = typing.TypeVar("T")


class WrappedSession:
    """Thin convenience wrapper around an ``AsyncSession``.

    Adds a few shortcut coroutines (``scalars``, ``scalar``, ``scalar_or_none``,
    ``add_and_flush``, ``delete_and_flush``). Any other attribute is looked up
    on the wrapped session, so callers can keep using the session's own API
    through the wrapper.
    """

    def __init__(self, session: AsyncSession):
        self._session: AsyncSession = session

    def __getattr__(self, item):
        """Delegate attribute lookups the wrapper cannot satisfy to the session.

        ``__getattr__`` only runs after normal lookup on the wrapper has failed.
        If that happens for ``_session`` itself, the instance was created
        without running ``__init__`` (e.g. by ``copy``/``pickle``); raising
        ``AttributeError`` here avoids the infinite recursion that evaluating
        ``self._session`` would otherwise trigger.

        :raises AttributeError: if ``_session`` is unset, or the wrapped
            session has no attribute ``item``.
        """
        if item == "_session":
            raise AttributeError(item)
        return getattr(self._session, item)

    # NOTE(review): with SQLAlchemy 2.0 typing, ``select(Model)`` is
    # ``Select[tuple[Model]]``, so ``Select[T] -> list[T]`` may not line up for
    # type checkers — confirm against the SQLAlchemy version in use.
    async def scalars(self, query: Select[T]) -> list[T]:
        """Execute ``query`` and return the scalar values of all result rows as a list."""
        result = await self._session.execute(query)
        return list(result.scalars())

    async def scalar(self, query: Select[T]) -> T:
        """Execute ``query`` and return its single scalar result.

        Delegates to ``Result.scalar_one()``, which raises if the query does not
        produce exactly one row.
        """
        result = await self._session.execute(query)
        return result.scalar_one()

    async def scalar_or_none(self, query: Select[T]) -> T | None:
        """Execute ``query`` and return its scalar result, or ``None`` if there are no rows.

        Delegates to ``Result.scalar_one_or_none()``, which raises if more than
        one row is returned.
        """
        result = await self._session.execute(query)
        return result.scalar_one_or_none()

    async def add_and_flush(self, instance: object) -> None:
        """Add ``instance`` to the session and flush only that object."""
        self._session.add(instance)
        # Restrict the flush to this instance rather than all pending changes.
        await self._session.flush([instance])

    async def delete_and_flush(self, instance: object) -> None:
        """Mark ``instance`` for deletion and flush only that object."""
        # Lazy %-formatting: same text as the former f"{instance=}" form,
        # but the repr is only built when DEBUG logging is enabled.
        logger.debug("delete_and_flush instance=%r", instance)
        await self._session.delete(instance)
        await self._session.flush([instance])


class Database:
    """Hands out ``WrappedSession`` objects backed by a database connector.

    New sessions are opened through the ``database_connector`` stored in
    :class:`Database.Context`.
    """

    @dataclass
    class Context:
        """Dependencies required by :class:`Database`."""

        # Used by ``ensure_session`` to open new sessions via ``get_session()``.
        database_connector: AbstractDatabaseConnector

    def __init__(self, context: Context) -> None:
        self._context = context

    @asynccontextmanager
    async def ensure_session(
        self,
        unsafe_session: AsyncSession | WrappedSession | None = None,
    ) -> typing.AsyncIterator[WrappedSession]:
        """Provide a ``WrappedSession``, reusing the caller's session when given.

        :param unsafe_session: an existing session owned by the caller. A raw
            ``AsyncSession`` is wrapped in ``WrappedSession``; any other
            non-``None`` value (normally an already-wrapped session) is yielded
            unchanged. In both cases this method neither commits nor closes it,
            so transaction control stays with the owner.
        :return: an async context manager yielding a ``WrappedSession``.

        When ``unsafe_session`` is ``None``, a new session is opened from the
        connector and committed once the ``async with`` body finishes without
        an exception. If the body raises, the commit is skipped and the
        exception propagates through the connector's ``get_session()`` context
        manager (whether that rolls back is decided by the connector, which is
        not visible here).
        """
        if unsafe_session is not None:
            # Caller-owned session: reuse it, never commit it here.
            if isinstance(unsafe_session, AsyncSession):
                yield WrappedSession(unsafe_session)
            else:
                yield unsafe_session
            return

        async with self._context.database_connector.get_session() as session:
            yield WrappedSession(session)
            # Reached only when the caller's block exited normally.
            await session.commit()
