#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import asyncio
import shlex

from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Any

import mistune
import mistune.directives

from run_markdown.json_comparator import JsonComparator
from run_markdown.text_comparator import TextComparator


@dataclass
class RunResult:
    stdout: str = ''
    process: asyncio.subprocess.Process | None = None


async def execute_code_block(code: str, language: str, path: Path, timeout: float = 1) -> RunResult:
    """Run ``code`` via ``script -q -c`` with working directory ``path`` and collect its stdout.

    ``python`` code is run as ``python -c <code>``. ``shell``/``bash``/``sh`` code
    is passed to ``script -c`` unchanged. NOTE(review): ``script`` is presumably used
    to give the child a pseudo-terminal (e.g. for unbuffered/interactive output) — confirm.
    NOTE(review): without a file argument, util-linux ``script`` may also record to a
    ``typescript`` file in ``path`` — confirm this is acceptable.

    If the process has not exited after ``timeout`` seconds it is left running and
    returned in ``RunResult.process`` so the caller can terminate it later. Whatever
    it has written to stdout so far is still collected.

    Args:
        code: Source code of the block.
        language: Fence language of the block.
        path: Working directory for the process.
        timeout: Seconds to wait for the process to exit.

    Returns:
        A ``RunResult`` whose ``stdout`` is always a ``str``.

    Raises:
        ValueError: If ``language`` is not supported.
    """
    if language == 'python':
        process_args = ['script', '-q', '-c', f"python -c {shlex.quote(code)}"]
    elif language in ('shell', 'bash', 'sh'):
        process_args = ['script', '-q', '-c', code]
    else:
        raise ValueError(f'Unknown language: {language!r}')

    # NOTE(review): stderr is piped but never read below.
    process: asyncio.subprocess.Process = await asyncio.create_subprocess_exec(
        *process_args, cwd=str(path),
        stdout=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)

    result = RunResult()
    try:
        await asyncio.wait_for(process.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        # Still running: hand it to the caller instead of killing it here.
        result.process = process

    stdout = b''
    # Read whole lines first. The short per-read timeout stops us from blocking
    # on a still-running process that has nothing more to print yet.
    try:
        while line := await asyncio.wait_for(process.stdout.readline(), timeout=0.1):
            stdout += line
    except asyncio.TimeoutError:
        pass

    # Then pick up any trailing partial line one byte at a time.
    try:
        while chunk := await asyncio.wait_for(process.stdout.read(1), timeout=0.1):
            stdout += chunk
    except asyncio.TimeoutError:
        pass

    try:
        body = stdout.decode()
    except UnicodeDecodeError as e:
        # Keep only the valid UTF-8 prefix (e.g. output cut mid-character by a
        # read timeout). Decode it so ``RunResult.stdout`` stays a ``str``.
        body = stdout[:e.start].decode()
    result.stdout = body
    return result


async def terminate_running_processes(running_processes: list[asyncio.subprocess.Process], depth: int = 0) -> None:
    """Stop every process in ``running_processes``, escalating to kill if needed.

    Each live process is sent ``terminate()`` and given up to one second to exit.
    If it does not exit in time, it is ``kill()``-ed and then awaited so it is
    reaped. Processes that have already exited are only logged.

    Args:
        running_processes: Processes to stop.
        depth: Indentation level of the log line printed for each process.
    """
    for process in running_processes:
        print(' ' * depth + f'terminate {process}')
        if process.returncode is not None:
            # Already exited and reaped; nothing to signal.
            continue
        try:
            process.terminate()
            try:
                await asyncio.wait_for(process.wait(), timeout=1)
            except asyncio.TimeoutError:
                process.kill()
                # Reap the killed process so it does not linger unreaped.
                await process.wait()
        except ProcessLookupError:
            # The process disappeared between the check above and the signal.
            pass


def preview(value: Any, max_len: int = 64, separator: str = '...') -> str:
    """Return ``repr(value)``, shortened to at most ``max_len`` characters.

    When the repr is too long, it is cut and ``separator`` is appended so the
    result is exactly ``max_len`` characters. If ``max_len`` is shorter than the
    separator, the separator itself is truncated.

    Args:
        value: Object to preview.
        max_len: Maximum length of the returned string.
        separator: Marker appended to a truncated repr.
    """
    text = repr(value)
    if len(text) <= max_len:
        return text
    # Clamp so a max_len shorter than the separator does not become a negative
    # slice, which would keep almost the whole repr.
    room = max(max_len - len(separator), 0)
    return (text[:room] + separator)[:max(max_len, 0)]


@dataclass
class Check:
    """One runnable code block from the markdown, plus its optional expected output."""

    # Source code of the block.
    code: str
    # Fence language; execute_code_block accepts 'python', 'shell', 'bash', 'sh'.
    language: str
    # Format of ``expected``: 'json' selects JsonComparator, any other value
    # (including None) selects TextComparator.
    expected_format: Literal['text', 'json'] | None = None
    # Expected output; when empty or None the actual output is not compared.
    expected: str | None = None
    # Optional path, relative to the run directory, that the code is also written to.
    path: str | None = None

    async def evaluate(self, tmp_dir: Path, depth: int = 0, timeout: float = 1) -> RunResult:
        """Write the code to ``path`` (if set) under ``tmp_dir``, then run it in ``tmp_dir``.

        Args:
            tmp_dir: Working directory for the run and base for ``path``.
            depth: Indentation level of the log line.
            timeout: Seconds to wait before detaching from the process.

        Returns:
            The ``RunResult`` from ``execute_code_block``.
        """
        print(' ' * depth + f"Run {self.language} code: {preview(self.code)}")
        if self.path:
            # Strip every leading slash: ``tmp_dir / '/x'`` would discard tmp_dir.
            # NOTE(review): '..' components are not rejected, so the path can still
            # point outside tmp_dir.
            target_path = tmp_dir / self.path.lstrip("/")
            # Support nested targets such as ``sub/dir/file.py``.
            target_path.parent.mkdir(parents=True, exist_ok=True)
            target_path.write_text(self.code, encoding='utf-8')
        return await execute_code_block(self.code, path=tmp_dir, language=self.language, timeout=timeout)

    def check_result(self, result: RunResult, depth: int = 0) -> None:
        """Print the outcome of ``result`` and compare it against ``expected``.

        Prints OK on a match and IGNORE when there is nothing to compare. On a
        mismatch, prints details and the raw output, then re-raises the
        comparator's exception.
        """
        if result.process is not None:
            print(' ' * depth + f"Timeout {self.language} code, detach from it and kill later")
        expected = self.expected
        actual = result.stdout
        print(' ' * depth + f'Got {preview(actual)}... ', end='')
        if expected:
            if self.expected_format == 'json':
                comparator = JsonComparator(expected)
            else:
                comparator = TextComparator(expected)
            comparison_result = comparator.compare(actual)
            if comparison_result.exception is not None:
                print('ERROR')
                print('ERROR DETAILS:')
                comparison_result.print()
                print('RAW DATA:')
                print(repr(actual))
                raise comparison_result.exception
            else:
                print('OK')
        else:
            print('IGNORE')


@dataclass
class LogicBlock:
    """A named section of the document: an ordered mix of checks and nested sections."""

    name: str
    # Children in document order: Check items and nested LogicBlock sections.
    content: list['Check | LogicBlock'] = field(default_factory=list)

    @property
    def child_blocks(self) -> list['LogicBlock']:
        """The nested sections in ``content``, in their original order."""
        return list(filter(lambda item: isinstance(item, LogicBlock), self.content))

    def add(self, block: 'Check | LogicBlock') -> None:
        """Append ``block`` to the end of ``content``."""
        self.content.append(block)

    async def evaluate(self, tmp_dir: Path, depth: int = 0, timeout: float = 1) -> None:
        """Evaluate every child in order.

        Nested sections evaluate recursively. Each check is run and its result is
        checked. Processes still running after their timeout are terminated when
        this block finishes, including when a check fails.
        """
        indent = ' ' * depth
        child_depth = depth + 1
        print(indent + f"Starting logic block {self.name!r}")
        leftovers = []
        try:
            for item in self.content:
                outcome = await item.evaluate(tmp_dir=tmp_dir, depth=child_depth, timeout=timeout)
                if outcome is None:
                    # Nested LogicBlock: nothing to check here.
                    continue
                if outcome.process is not None:
                    leftovers.append(outcome.process)
                assert isinstance(item, Check)
                item.check_result(outcome, child_depth)
        finally:
            await terminate_running_processes(leftovers, depth)


def get_logic_block_from_md_file(file_path: str) -> LogicBlock:
    """Read the UTF-8 markdown file at ``file_path`` and build its logic-block tree.

    The markdown is parsed with mistune (fenced admonition directives enabled).
    The root of the returned tree is named ``file_path``.
    """
    source_text = Path(file_path).read_text(encoding='utf-8')

    directive_plugin = mistune.directives.FencedDirective([mistune.directives.Admonition()])
    parser = mistune.create_markdown(plugins=[directive_plugin])
    # parse() returns (rendered output, block state); only the state's tokens are used.
    _rendered, state = parser.parse(source_text)

    return parse_tokens_as_logic(state.tokens, file_path)


# A paragraph ending in bold text like ``**file://path/to/file.py:**`` names the
# file that the next runnable code block is also written to.
FILE_PATH_PREFIX, FILE_PATH_POSTFIX = 'file://', ':'


def _inline_text(children: list[dict]) -> str:
    """Return the plain text of a list of mistune inline tokens.

    Leaf tokens contribute their ``raw`` field. Container tokens (emphasis,
    strong, links, ...) are flattened recursively. Line breaks become spaces.
    """
    parts: list[str] = []
    for child in children:
        if child.get('type') in ('softbreak', 'linebreak'):
            parts.append(' ')
        elif 'children' in child:
            parts.append(_inline_text(child['children']))
        else:
            parts.append(child.get('raw', ''))
    return ''.join(parts)


def parse_tokens_as_logic(tokens: list[dict], root_name: str) -> LogicBlock:
    """Build a ``LogicBlock`` tree from a flat list of mistune block tokens.

    - A heading opens a new ``LogicBlock`` nested by level: level 1 goes under the
      root, level 2 under the latest level-1 block, and so on. A skipped
      intermediate level gets a placeholder block named after the heading.
    - A ``python``/``shell``/``bash``/``sh`` fenced block becomes a ``Check`` in
      the current block.
    - The first fenced block in any other (or no) language after a ``Check`` (and
      before the next heading) becomes that check's expected output.
    - A paragraph ending in bold ``**file://<path>:**`` sets ``path`` for the next ``Check``.

    Args:
        tokens: Block-level tokens from mistune's parse state.
        root_name: Name of the returned root block.

    Returns:
        The root ``LogicBlock``.
    """
    root = current_block = LogicBlock(name=root_name)
    last_check = None
    path = None
    for token in tokens:
        token_type = token['type']
        if token_type == 'paragraph':
            inline = token.get('children') or []
            # Guard against an empty paragraph before inspecting its last child.
            if inline and inline[-1]['type'] == 'strong':
                strong_children = inline[-1].get('children', [])
                if len(strong_children) == 1 and strong_children[0]['type'] == 'text':
                    text = strong_children[0]['raw']
                    if text.startswith(FILE_PATH_PREFIX) and text.endswith(FILE_PATH_POSTFIX):
                        path = text.removeprefix(FILE_PATH_PREFIX).removesuffix(FILE_PATH_POSTFIX)
        elif token_type == 'block_code':
            language: str | None = token.get('attrs', {}).get('info', None)
            if language in ('python', 'shell', 'bash', 'sh'):
                last_check = Check(code=token['raw'], language=language, path=path)
                path = None
                current_block.add(last_check)
            elif last_check is not None:
                # A non-runnable fence following a check holds its expected output.
                last_check.expected_format = language
                last_check.expected = token['raw']
                last_check = None
        elif token_type == 'heading':
            last_check = None
            # Flatten inline markup so headings like "## Run `foo`" are accepted.
            block_name = _inline_text(token.get('children', []))
            target_parent = root
            target_depth = token.get('attrs', {}).get('level', 0)
            # Walk down to this level's parent, creating placeholders for skipped levels.
            for _ in range(target_depth - 1):
                if not target_parent.child_blocks:
                    target_parent.add(LogicBlock(block_name))
                target_parent = target_parent.child_blocks[-1]
            current_block = LogicBlock(block_name)
            target_parent.add(current_block)

    return root
