#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import io
import logging
import typing
import zipfile
from dataclasses import dataclass
from secrets import token_hex

from file_storage.exceptions import PathDoesNotExists

from .database import Database
from .extra import AnySession, assert_raise
from .link_keeper import LinkKeeper
from .entities.link_info import LinkInfo

logger: logging.Logger = logging.getLogger(__name__)
# Type variable for the link-info type handled by MemoryZipper; restricted to LinkInfo subtypes.
LI = typing.TypeVar('LI', bound=LinkInfo)


class MaxZipFilesError(Exception):
    """Raised when more files are requested for one archive than the configured limit allows."""


class MaxZipSizeError(Exception):
    """Raised when the combined size of the requested files exceeds the configured limit."""


class MemoryZipper(typing.Generic[LI]):
    """Build a ZIP archive in memory from the files behind a set of link ids.

    The number of requested files and their combined size are bounded by
    :class:`MemoryZipper.Config`.
    """

    @dataclass
    class Config:
        # Maximum number of link ids accepted by a single ``zip_files`` call (inclusive).
        max_zip_files: int = 100
        # Maximum combined ``file_info.size`` of the requested files, in bytes (inclusive).
        max_zip_size: int = 100 * 1024 * 1024

    @dataclass
    class Context:
        # Used to fetch link metadata (``head_multiple``) and file payloads (``read``).
        link_keeper: LinkKeeper
        # Supplies the session for the metadata lookup via ``ensure_session``.
        database: Database

    def __init__(self, context: Context, config: typing.Optional[Config] = None) -> None:
        self.config = config or self.Config()
        self.context = context

    async def zip_files(self, link_ids: list[str], session: AnySession = None) -> tuple[LI, bytes]:
        """Read the files referenced by ``link_ids`` and pack them into one ZIP archive.

        :param link_ids: ids of the links whose files go into the archive.
        :param session: optional database session, passed to
            ``database.ensure_session`` for the metadata lookup.
        :return: ``(metadata, content)`` - metadata produced by
            ``FileMetadata.prepare`` for a randomly named ``.zip`` file, and the
            raw archive bytes.
        :raises MaxZipFilesError: more than ``config.max_zip_files`` ids were given.
        :raises PathDoesNotExists: some requested ids were not returned by
            ``link_keeper.head_multiple``.
        :raises MaxZipSizeError: the combined ``file_info.size`` of the files is
            greater than ``config.max_zip_size``.
        """
        logger.info('Zip links: link_ids=%r', link_ids)
        assert_raise(len(link_ids) <= self.config.max_zip_files,
                     MaxZipFilesError(f"Max files amount limit exceeded: {len(link_ids)}/{self.config.max_zip_files}"))

        async with self.context.database.ensure_session(session) as session:
            link_infos = await self.context.link_keeper.head_multiple(link_ids, session)

            # Requested ids that the lookup did not return (requested minus found,
            # not the other way round).
            found_link_ids = {link_info.id for link_info in link_infos}
            missing_link_ids = set(link_ids) - found_link_ids
            assert_raise(not missing_link_ids,
                         PathDoesNotExists(f'Some files not found: {missing_link_ids}'))

            # Inclusive limit, consistent with the file-count check above.
            total_size = sum(link_info.file_info.size for link_info in link_infos)
            assert_raise(total_size <= self.config.max_zip_size,
                         MaxZipSizeError(f"Max files size limit exceeded: {total_size}/{self.config.max_zip_size}"))

        zip_buffer = io.BytesIO()

        # allowZip64=False: ZIP64 extensions are disabled, so zipfile raises
        # LargeZipFile if the archive would require them.
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED, allowZip64=False) as zip_file:
            for link_info in link_infos:
                file_payload = await self.context.link_keeper.read(link_info.file_info_key)
                # zipfile treats '/' in an entry name as a directory separator, so a key
                # like 'foo/bar.pdf' would appear as file 'bar.pdf' inside folder 'foo'.
                # Replace it to keep every file at the archive root.
                zip_file.writestr(link_info.file_info_key.replace('/', '-'), file_payload)

        file_content = zip_buffer.getvalue()
        zip_extension = 'zip'
        # NOTE(review): FileMetadata is not among the imports at the top of this
        # module, so this line raises NameError unless it is defined elsewhere;
        # confirm its source module and import it.
        zip_metadata = FileMetadata.prepare(file_content, f'{token_hex(16)}.{zip_extension}', zip_extension)

        return zip_metadata, file_content
