#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import sqlalchemy as sa


def extract_columns(expression) -> set[sa.Column]:
    """Collect every ``sa.Column`` referenced by a SQL expression tree.

    The tree is walked with an explicit stack instead of recursion, so
    deeply nested expressions (e.g. long left-deep chains of binary
    operators) cannot exhaust Python's recursion limit. From each node the
    walk follows:

    * for ``sa.Case``: the simple-CASE ``value``, every WHEN condition and
      result, and ``else_``;
    * the ``clauses`` and ``elements`` sequences;
    * the single sub-expressions ``element``, ``clause`` (the operand of
      ``Cast`` / ``TypeCoerce``), ``left`` and ``right``.

    Sub-structures reachable only through other attributes are not
    visited. Only instances of ``sa.Column`` are collected. A node that is
    a column is still inspected for the attributes above like any other
    node.

    :param expression: root of the expression tree; ``None`` is accepted
        and yields an empty set.
    :return: a new set of the distinct columns found.
    """
    columns: set[sa.Column] = set()
    # id -> node for every node already expanded. Holding the node itself
    # (not just its id) keeps temporaries returned by properties alive, so
    # an id can never be recycled and a new node wrongly skipped. The guard
    # also means shared sub-expressions are walked once and a cyclic
    # structure cannot loop forever.
    visited: dict[int, object] = {}
    pending: list[object] = [expression]
    while pending:
        node = pending.pop()
        if node is None or id(node) in visited:
            continue
        visited[id(node)] = node

        if isinstance(node, sa.Column):
            columns.add(node)
        if isinstance(node, sa.Case):
            # A simple CASE (``case({...}, value=col)``) keeps the compared
            # expression only in ``value``; its WHEN conditions hold the
            # literals being matched, not the column.
            pending.append(getattr(node, 'value', None))
            pending.append(node.else_)
            for case_condition, case_value in node.whens:
                pending.append(case_condition)
                pending.append(case_value)
        if hasattr(node, 'clauses'):
            pending.extend(node.clauses)
        if hasattr(node, 'elements'):
            pending.extend(node.elements)
        for attr in ('element', 'clause', 'left', 'right'):
            if hasattr(node, attr):
                pending.append(getattr(node, attr))
    return columns


def get_column_to_dependency_columns(
        column_to_expression: dict[sa.Column, sa.ColumnElement | None]
) -> dict[sa.Column, set[sa.Column]]:
    """Map each column to the set of columns its expression references.

    Columns whose expression is ``None`` are left out of the result. The
    remaining entries keep the input's iteration order. Each value is a
    new set holding the columns that ``extract_columns`` finds in that
    column's expression.

    :param column_to_expression: column -> SQL expression (or ``None``).
    :return: column -> columns referenced by its expression.
    """
    return {
        column: set(extract_columns(expression))
        for column, expression in column_to_expression.items()
        if expression is not None
    }
