#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import json
import shutil
import subprocess
import traceback
from dataclasses import dataclass, field
from typing import Literal, Any

import mistune


@dataclass
class RunResult:
    """Outcome of a snippet that finished before its timeout.

    ``execute_code_block`` builds one from ``Popen.communicate()`` output of a
    process started with ``text=True``, so both streams are ``str``.
    """

    stdout: str  # everything the snippet wrote to standard output
    stderr: str  # everything the snippet wrote to standard error
    returncode: int  # exit status reported by the finished process


def execute_code_block(code: str, language: str, timeout: float = 1) -> RunResult | subprocess.Popen:
    """Run a Python or shell snippet and wait up to ``timeout`` seconds for it.

    Args:
        code: Source text of the snippet.
        language: ``'python'`` runs ``python -c code``. ``'shell'`` and ``'sh'``
            run ``code`` through the default shell (``shell=True``). ``'bash'``
            runs it with ``bash -c`` when a ``bash`` executable is on ``PATH``,
            falling back to the default shell otherwise.
        timeout: Seconds to wait for the snippet to finish.

    Returns:
        A ``RunResult`` with the captured text output and exit code if the
        process finished in time. Otherwise the still-running
        ``subprocess.Popen``, which the caller is responsible for terminating.

    Raises:
        ValueError: If ``language`` is not one of the supported names.
    """
    if language == 'python':
        command, use_shell = ['python', '-c', code], False
    elif language in ('shell', 'bash', 'sh'):
        # With shell=True, POSIX always uses /bin/sh, which is often not bash (e.g. dash).
        # Bash-only syntax in a ```bash block would then fail, so prefer real bash for it.
        bash_path = shutil.which('bash') if language == 'bash' else None
        if bash_path is not None:
            command, use_shell = [bash_path, '-c', code], False
        else:
            command, use_shell = code, True
    else:
        raise ValueError('Unknown language')

    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, shell=use_shell)
    try:
        stdout, stderr = process.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        # communicate() does not kill the child on timeout; hand it back still running.
        return process
    except BaseException as exc:
        if isinstance(exc, Exception):
            traceback.print_exc()
        # No caller will ever see this process, so stop and reap it before propagating.
        process.kill()
        process.wait()
        raise
    return RunResult(stdout=stdout, stderr=stderr, returncode=process.returncode)


def terminate_running_processes(running_processes: list[subprocess.Popen], depth: int = 0) -> None:
    """Stop every given process and release the resources it holds.

    For each process that is still running:
    1. send ``terminate()``;
    2. if it has not exited within 5 seconds, send ``kill()``.

    Every process is then reaped, so none lingers as a zombie. Its pipes are
    closed, so no file descriptors leak.

    NOTE(review): for snippets started with ``shell=True``, the signals reach
    only the shell, not processes it spawned. Process groups would be needed
    to stop those too.

    Args:
        running_processes: Processes that were detached after a timeout.
        depth: Indentation (in spaces) for the progress messages.
    """
    for process in running_processes:
        print(' ' * depth + f'terminate {process}')
        if process.poll() is None:
            process.terminate()
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                process.kill()
                # SIGKILL cannot be ignored, so this wait returns promptly and reaps the child.
                process.wait()
        for stream in (process.stdin, process.stdout, process.stderr):
            if stream is not None:
                stream.close()


def preview(value: Any, max_len: int = 64, separator: str = '...') -> str:
    """Return ``repr(value)`` shortened to at most ``max_len`` characters for log output.

    When the repr is too long, its tail is replaced with ``separator``. The
    result never exceeds ``max_len`` characters, even when ``max_len`` is shorter
    than ``separator`` itself; the separator is then cut as well.

    >>> preview('abc')
    "'abc'"
    >>> preview('abcdefgh', max_len=6)
    "'ab..."
    """
    text = repr(value)
    if len(text) <= max_len:
        return text
    # Clamp at zero: a negative slice index would keep most of the string and overshoot max_len.
    keep = max(max_len - len(separator), 0)
    return (text[:keep] + separator)[:max(max_len, 0)]


@dataclass
class Check:
    """One runnable code block from the markdown, plus the output it should print.

    ``parse_tokens_as_logic`` creates a Check for every ``python``/``shell``/
    ``bash``/``sh`` code block. It fills ``expected``/``expected_format`` from the
    next non-runnable code block that follows it.
    """

    code: str  # snippet source, passed verbatim to execute_code_block
    language: str  # language name forwarded to execute_code_block
    # How ``expected`` is compared: 'json' compares parsed JSON values; anything else compares raw text.
    expected_format: Literal['text', 'json'] | None = None
    # Expected stdout. None or '' means the output is not compared; any non-None value makes a timeout fail.
    expected: str | None = None

    def evaluate(self, depth: int = 0) -> subprocess.Popen | None:
        """Run the snippet and compare its stdout with ``expected``.

        Args:
            depth: Indentation (in spaces) for progress messages.

        Returns:
            The still-running process if the snippet timed out and no output was
            expected; the caller must terminate it later. Otherwise ``None``.

        Raises:
            AssertionError: If the snippet timed out although output was
                expected (the process is killed first), or if its output does
                not match.
            json.JSONDecodeError: If ``expected_format`` is ``'json'`` and
                either side is not valid JSON.
        """
        print(' ' * depth + f"Run {self.language} code: {preview(self.code)}")
        process_or_result = execute_code_block(self.code, language=self.language)
        if isinstance(process_or_result, subprocess.Popen):
            print(' ' * depth + f"Timeout {self.language} code, detach from it and kill later")
            if self.expected is not None:
                # This process is not returned, so nobody would terminate it later: stop it now.
                process_or_result.kill()
                process_or_result.wait()
                # Explicit raise rather than `assert`, which is stripped under `python -O`.
                raise AssertionError(f"{self.code=} got unexpected timeout")
            return process_or_result
        result = process_or_result
        print(' ' * depth + f'Got {preview(result.stdout)}... ', end='')
        # Decide on the raw text, not the parsed value: JSON such as [] or 0 is falsy
        # but is still a real expectation.
        if not self.expected:
            print('IGNORE')
            return None
        expected, actual = self.expected, result.stdout
        if self.expected_format == 'json':
            expected = json.loads(expected)
            actual = json.loads(actual)
        if actual != expected:
            raise AssertionError(f"{self.code=} result does not match expected: {actual!r} != {expected!r}")
        print('OK')
        return None


@dataclass
class LogicBlock:
    """A named, ordered group of checks and nested groups (one per markdown heading).

    Children are evaluated in order. Processes that timed-out children left
    running are terminated when the group finishes, whether it succeeded or
    a child raised.
    """

    name: str
    content: list['Check | LogicBlock'] = field(default_factory=list)

    def add(self, block: 'Check | LogicBlock') -> None:
        """Append ``block`` as the last child of this group."""
        self.content.append(block)

    def evaluate(self, depth: int = 0) -> None:
        """Evaluate every child one level deeper, then stop any processes they detached."""
        indent = ' ' * depth
        print(f"{indent}Starting logic block {self.name!r}")
        detached: list[subprocess.Popen] = []
        try:
            for child in self.content:
                outcome = child.evaluate(depth + 1)
                if isinstance(outcome, subprocess.Popen):
                    detached.append(outcome)
        finally:
            # Also runs when a child raises, so timed-out processes never outlive their group.
            terminate_running_processes(detached, depth)


def get_logic_block_from_md_file(file_path: str) -> LogicBlock:
    """Read a UTF-8 markdown file and build its LogicBlock tree, rooted at ``file_path``.

    The text is parsed with ``mistune.create_markdown()``. Only the resulting
    block-token stream is used; the rendered output is discarded.
    """
    with open(file_path, encoding='utf-8') as md_file:
        markdown_source = md_file.read()

    parser = mistune.create_markdown()
    # ``parse`` returns ``(rendered_output, state)``; the tokens live on the state.
    _rendered, block_state = parser.parse(markdown_source)
    return parse_tokens_as_logic(block_state.tokens, file_path)


def _heading_text(children: list[dict]) -> str:
    """Flatten a heading's inline tokens into plain text.

    - Tokens carrying ``raw`` (plain text, inline code) contribute it verbatim.
    - Line breaks become a single space.
    - Container tokens (emphasis, links, ...) are flattened recursively through
      their ``children``.
    """
    parts = []
    for child in children:
        if 'raw' in child:
            parts.append(child['raw'])
        elif child.get('type') in ('softbreak', 'linebreak'):
            parts.append(' ')
        elif 'children' in child:
            parts.append(_heading_text(child['children']))
    return ''.join(parts)


def parse_tokens_as_logic(tokens: list[dict], root_name: str) -> LogicBlock:
    """Build a tree of LogicBlock/Check objects from mistune block tokens.

    - A heading of level N adds a new LogicBlock N-1 levels below the root. Each
      step descends into the parent's last child. When that child is missing or
      is a Check, a placeholder block named after the heading is inserted.
    - A ``python``/``shell``/``bash``/``sh`` code block becomes a Check in the
      block of the most recent heading (the root before any heading).
    - The next other code block after a Check (before any heading) becomes that
      Check's ``expected`` text, with its info string as ``expected_format``.

    Args:
        tokens: Block tokens, as found in ``BlockState.tokens``.
        root_name: Name of the returned root block.

    Returns:
        The root LogicBlock.
    """
    root = current_block = LogicBlock(name=root_name)
    last_check = None  # Check still waiting for its expected-output block, if any
    for token in tokens:
        if token['type'] == 'block_code':
            language: str | None = token.get('attrs', {}).get('info', None)
            if language in ('python', 'shell', 'bash', 'sh'):
                last_check = Check(code=token['raw'], language=language)
                current_block.add(last_check)
            elif last_check is not None:
                last_check.expected_format = language
                last_check.expected = token['raw']
                last_check = None
        elif token['type'] == 'heading':
            last_check = None
            block_name = _heading_text(token['children'])
            target_parent = root
            target_depth = token.get('attrs', {}).get('level', 0)
            for _ in range(target_depth - 1):
                last_child = target_parent.content[-1] if target_parent.content else None
                if not isinstance(last_child, LogicBlock):
                    # Skipped heading level, or the parent ends with a Check (which has no
                    # children): insert a placeholder group instead of descending into a Check.
                    last_child = LogicBlock(block_name)
                    target_parent.add(last_child)
                target_parent = last_child
            current_block = LogicBlock(block_name)
            target_parent.add(current_block)

    return root

#
# def main():
#     block = get_logic_block_from_md_file('README_check.md')
#     try:
#         block.evaluate()
#     except Exception:
#         print("CHECK FAILED")
#         traceback.print_exc()
#         exit(1)
#     print("CHECK SUCCESS")
