from __future__ import annotations
import inspect
import json
import operator
import sys
import traceback
from collections.abc import Callable, Iterable
from typing import Any, Self, final
from . import LOGGER, TooFewChildren, TooManyChildren, EdgeValueError, NodeValueError, NodeNotFountError
# Child logger named 'abc' under the package-level LOGGER.
LOGGER = LOGGER.getChild('abc')

# Public API of this module.
__all__ = [
    'LGM',
    'LogicGroup',
    'SkipContextsBlock',
    'LogicExpression',
    'ExpressionCollection',
    'LogicNode',
    'ActionNode',
    'ELSE_CONDITION',
]
class Singleton(type):
    """Metaclass whose classes produce exactly one shared instance.

    The first call constructs and stores the instance; every later call returns
    that stored object and ignores its arguments.
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        known = cls._instances
        if cls in known:
            return known[cls]
        created = super(Singleton, cls).__call__(*args, **kwargs)
        known[cls] = created
        return created
class ConditionElse(object):
    """Sentinel edge condition meaning "otherwise" in a decision tree.

    Its string form is empty, so an edge labelled with ``str(condition)``
    shows no text for it.
    """

    def __str__(self) -> str:
        return ""


# A single shared sentinel object; both names refer to it.
ELSE_CONDITION = NO_CONDITION = ConditionElse()
class LogicGroupManager(metaclass=Singleton):
    """Manager for LogicGroup instances and runtime expression context.

    Handles caching and reuse of `LogicGroup` objects and manages the runtime
    stacks used while building and evaluating decision graphs. Because of the
    `Singleton` metaclass every ``LogicGroupManager()`` call returns the same
    object, exposed at module level as ``LGM``.
    """

    def __init__(self):
        """Initialize the manager and its runtime state.

        Attributes (initialized):
            _cache (dict): Name -> LogicGroup cache.
            _active_groups (list[LogicGroup]): Stack of active logic groups.
            _active_nodes (list[LogicNode]): Stack of active expression nodes.
            _breakpoint_nodes (list[ActionNode]): Breakpoint action nodes recorded
                during inspection-mode breaks whose group has not exited yet.
            _pending_connection_nodes (list[ActionNode]): Breakpoint nodes that
                must be connected to the next entered expression node.
            _shelved_state (list[dict]): Stack of saved snapshots for shelve/unshelve.
            inspection_mode (bool): When True, expression entry checks are bypassed.
            vigilant_mode (bool): If True, perform stricter validation when building graphs.
        """
        # Name -> LogicGroup cache.
        self._cache = {}
        # Stack of currently active LogicGroups (innermost last).
        self._active_groups: list[LogicGroup] = []
        # Stack of currently active expression nodes (innermost last).
        self._active_nodes: list[LogicNode] = []
        # Action nodes (usually NoAction nodes) marked as an early exit (breakpoint) of a logic group.
        self._breakpoint_nodes: list[ActionNode] = []
        # Breakpoint nodes whose group has exited; connected to the next entered expression node.
        self._pending_connection_nodes: list[ActionNode] = []
        # Snapshots saved by `shelve` so a separate node graph can be built temporarily.
        self._shelved_state = []
        self.inspection_mode = False
        self.vigilant_mode = False

    def __call__(self, name: str, cls: type[LogicGroup], **kwargs) -> LogicGroup:
        """Return a cached LogicGroup by name or create and cache a new one.

        Args:
            name: Logical name of the LogicGroup.
            cls: Class to instantiate if no cached group exists.
            **kwargs: Passed to the LogicGroup constructor.

        Returns:
            LogicGroup: The cached or newly created LogicGroup instance.
        """
        if name in self._cache:
            return self._cache[name]
        logic_group = cls(name=name, **kwargs)
        self._cache[name] = logic_group
        return logic_group

    def __contains__(self, name: str) -> bool:
        """Return True if a LogicGroup with `name` is cached."""
        return name in self._cache

    def __getitem__(self, name: str) -> LogicGroup:
        """Get a cached LogicGroup by name.

        Raises:
            KeyError: If the name is not present in the cache.
        """
        return self._cache[name]

    def __setitem__(self, name: str, value: LogicGroup):
        """Set or replace the cached LogicGroup stored under `name`."""
        self._cache[name] = value

    def enter_logic_group(self, logic_group: LogicGroup):
        """Push `logic_group` onto the active stack (see ``active_logic_group``)."""
        self._active_groups.append(logic_group)

    def exit_logic_group(self, logic_group: LogicGroup):
        """Pop `logic_group` from the active stack and release its breakpoints.

        Breakpoint nodes whose ``break_from`` is `logic_group` are moved from
        ``_breakpoint_nodes`` to ``_pending_connection_nodes`` so that the next
        entered expression node is connected to them. They are removed from
        ``_breakpoint_nodes`` so that a later exit of the same (cached) group
        does not queue them a second time.

        Args:
            logic_group: The LogicGroup exiting the context.

        Raises:
            ValueError: If `logic_group` is not the innermost active group.
        """
        if not self._active_groups or self._active_groups[-1] is not logic_group:
            raise ValueError("The LogicGroup is not currently active.")
        self._active_groups.pop(-1)

        remaining = []
        for node in self._breakpoint_nodes:
            if getattr(node, 'break_from', None) is logic_group:
                self._pending_connection_nodes.append(node)
            else:
                remaining.append(node)
        # Update in place so the list object itself is preserved.
        self._breakpoint_nodes[:] = remaining

    def enter_expression(self, node: LogicNode):
        """Register that an expression node has become active.

        Pending breakpoint nodes are first connected to `node`. A NoAction
        breakpoint is replaced by `node` in its parent. Any other breakpoint
        gets `node` as its unconditional (``NO_CONDITION``) child. After that,
        `node` is recorded as a subordinate of the currently active expression
        (if any) and pushed onto the active stack.

        Args:
            node: The LogicNode being entered.

        Raises:
            NodeNotFountError: If a pending NoAction breakpoint has no parent.
        """
        # Entering the `with` block of an ActionNode is unexpected: it is logged,
        # but the node is still registered below.
        if isinstance(node, ActionNode):
            LOGGER.error('Enter the with code block of an ActionNode rejected. Check is this intentional?')
        if self._pending_connection_nodes:
            from .node import NoAction
            for _exit_node in self._pending_connection_nodes:
                if isinstance(_exit_node, NoAction):
                    if (parent := _exit_node.parent) is None:
                        raise NodeNotFountError('ActionNode must have a parent node!')
                    parent.replace(original_node=_exit_node, new_node=node)
                else:
                    _exit_node.edges.append(NO_CONDITION)
                    _exit_node.nodes[NO_CONDITION] = node
            self._pending_connection_nodes.clear()
        if (active_node := self.active_expression) is not None:
            active_node.subordinates.append(node)
        self._active_nodes.append(node)

    def exit_expression(self, node: LogicNode):
        """Unregister an expression node when its context exits.

        Args:
            node: The expression node being exited.

        Raises:
            ValueError: If `node` is not the current active expression (i.e., exit
                order violated).
        """
        if not self._active_nodes or self._active_nodes[-1] is not node:
            raise ValueError(f"The {node} is not currently active.")
        self._active_nodes.pop(-1)

    def shelve(self):
        """Temporarily save and clear runtime node/breakpoint/pending state.

        The current ``_active_nodes``, ``_breakpoint_nodes`` and
        ``_pending_connection_nodes`` lists are copied into a snapshot that is
        appended to ``_shelved_state``. Those lists are then cleared so a
        separate, isolated evaluation or inspection context can be created.

        Returns:
            dict: The saved snapshot containing keys 'active_nodes',
            'breakpoint_nodes' and 'pending_connection_nodes'.
        """
        shelved_state = dict(
            active_nodes=self._active_nodes.copy(),
            breakpoint_nodes=self._breakpoint_nodes.copy(),
            pending_connection_nodes=self._pending_connection_nodes.copy()
        )
        self._active_nodes.clear()
        self._breakpoint_nodes.clear()
        self._pending_connection_nodes.clear()
        self._shelved_state.append(shelved_state)
        return shelved_state

    def unshelve(self, reset_active: bool = True, reset_breakpoints: bool = True, reset_pending: bool = True):
        """Restore the most recent shelved snapshot.

        Snapshot entries are inserted at the front of each list. When a list is
        not reset first, its current entries therefore remain after the restored ones.

        Args:
            reset_active: If True, clear ``_active_nodes`` before restoring the snapshot.
            reset_breakpoints: If True, clear ``_breakpoint_nodes`` before restoring.
            reset_pending: If True, clear ``_pending_connection_nodes`` before restoring.

        Returns:
            dict: The restored snapshot (same shape as returned by `shelve`).

        Raises:
            IndexError: If there is no shelved snapshot to unshelve.
        """
        shelved_state = self._shelved_state.pop(-1)
        if reset_active:
            self._active_nodes.clear()
        if reset_breakpoints:
            self._breakpoint_nodes.clear()
        if reset_pending:
            self._pending_connection_nodes.clear()
        self._active_nodes[:0] = shelved_state['active_nodes']
        self._breakpoint_nodes[:0] = shelved_state['breakpoint_nodes']
        self._pending_connection_nodes[:0] = shelved_state['pending_connection_nodes']
        return shelved_state

    def clear(self):
        """Reset the manager to an empty state.

        Clears the LogicGroup cache, the active group/node stacks, and any
        recorded breakpoint and pending-connection nodes, so no stale
        breakpoints leak into the next graph. Does not touch ``_shelved_state``
        or the mode flags.
        """
        self._cache.clear()
        self._active_groups.clear()
        self._active_nodes.clear()
        self._breakpoint_nodes.clear()
        self._pending_connection_nodes.clear()

    @property
    def active_logic_group(self) -> LogicGroup | None:
        """Return the currently active LogicGroup (top of active stack) or None."""
        if self._active_groups:
            return self._active_groups[-1]
        return None

    @property
    def active_expression(self) -> LogicNode | None:
        """Return the currently active expression node (top of active nodes) or None."""
        if self._active_nodes:
            return self._active_nodes[-1]
        return None


# The process-wide manager instance (Singleton metaclass).
LGM = LogicGroupManager()
class LogicGroupMeta(type):
    """Metaclass for LogicGroup that registers subclasses and caches instances.

    Every class created with this metaclass is recorded in ``_registry_`` under
    its class name. A class defined later with the same name replaces the earlier entry.
    Calling such a class returns the instance cached in ``LGM`` under the given
    `name`, and creates and caches it on first use.
    """
    _registry_ = {}

    def __new__(cls, name, bases, dct):
        new_class = super().__new__(cls, name, bases, dct)
        cls._registry_[name] = new_class
        return new_class

    def __call__(cls, name, *args, **kwargs):
        """Return the cached instance called `name`, creating it if needed.

        Raises:
            ValueError: If `name` is None.
        """
        if name is None:
            raise ValueError("LogicGroup instances must have a 'name'.")
        # Check the cache for an existing instance
        if name in LGM:
            return LGM[name]
        if args:
            # Pass `name` positionally so the extra positional arguments bind to
            # the parameters that follow it; combining `name=name` with
            # positionals raised "got multiple values for argument 'name'".
            instance = super().__call__(name, *args, **kwargs)
        else:
            # Keyword form (original behavior): also works for subclasses whose
            # `name` parameter is not first, e.g. ExpressionCollection(data=..., name=...).
            instance = super().__call__(name=name, **kwargs)
        LGM[name] = instance
        return instance

    @property
    def registry(self):
        """Mapping of class name -> class for every class using this metaclass."""
        return self._registry_
class LogicGroup(object, metaclass=LogicGroupMeta):
    """
    A minimal context manager to save/restore state from the `.contexts` dict.

    A logic group maintains no status itself; the status should be restored
    from the outer `.contexts` dict. Child groups persist their info dict
    (logic type, contexts, sub-logics) inside their parent's ``_sub_logics``.
    """

    def __init__(self, name: str, parent: Self = None, contexts: dict[str, Any] = None):
        """Initialize the group and bind it to its persisted info dict.

        Args:
            name: Group name; for a child group also its key in ``parent._sub_logics``.
            parent: Optional parent group. When given, this group's info dict is
                recovered from (or created in) ``parent._sub_logics[name]``.
            contexts: Initial contexts dict. For a child group it is only used
                when no contexts were previously recorded under `name`.

        Raises:
            AssertionError: If `name` is already registered in `parent` with a
                different logic type.
        """
        self.name = name
        self.parent = parent
        # One exception type per instance, so `__exit__` only swallows breaks
        # aimed at this group (it checks `exc_type is self.Break`).
        self.Break = type(f"{self.__class__.__name__}Break", (Exception,), {})

        if parent is None:
            # A root logic group: nothing to recover.
            info_dict = {}
            if contexts is None:
                contexts = {}
        else:
            # Try to recover from the parent.
            info_dict = parent._sub_logics.setdefault(name, {})
            logic_type = self.__class__.__name__
            # Record the type outside the assert so it still happens under `python -O`.
            registered_type = info_dict.setdefault('logic_type', logic_type)
            assert registered_type == logic_type, f"Logic {info_dict['logic_type']} already registered in {parent.name}!"
            contexts = info_dict.setdefault('contexts', {} if contexts is None else contexts)

        self.contexts: dict[str, Any] = contexts
        self._sub_logics = info_dict.setdefault('sub_logics', {})

    def __repr__(self):
        return f'<{self.__class__.__name__}>({self.name!r})'

    def __enter__(self) -> Self:
        LGM.enter_logic_group(self)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        LGM.exit_logic_group(self)
        if exc_type is None:
            return
        if exc_type is self.Break:
            return True
        # Explicitly re-raise other exceptions
        return False

    def break_(self, scope: LogicGroup = None):
        """Break out of `scope` (defaults to this group).

        Outside inspection mode this raises ``scope.Break``, which `scope`'s
        ``__exit__`` suppresses. In inspection mode the block is never left.
        Instead, the last leaf under the active expression node is marked as a
        breakpoint of `scope` (``break_from``) and recorded in
        ``LGM._breakpoint_nodes``.

        Raises:
            TooFewChildren: In inspection + vigilant mode, when the active
                expression node has no children yet.
            NodeValueError: In inspection mode, when the last leaf is not an ActionNode.
        """
        if scope is None:
            scope = self

        if not LGM.inspection_mode:
            raise scope.Break()

        # Inspection mode: record a breakpoint instead of leaving the block.
        active_node: LogicNode | None = LGM.active_expression
        if active_node is None:
            return

        if not active_node.nodes:
            if LGM.vigilant_mode:
                raise TooFewChildren()
            LOGGER.warning('Must have at least one action node before breaking from logic group. A NoAction node will be automatically assigned.')
            from .node import NoAction
            # NOTE(review): presumably attaches itself to the active node on construction — confirm in .node.
            NoAction()

        last_node = active_node.last_leaf
        # Explicit check (not `assert`) so it survives `python -O` and raises the intended error type.
        if not isinstance(last_node, ActionNode):
            raise NodeValueError('An ActionNode is required before breaking a LogicGroup.')
        last_node.break_from = scope
        LGM._breakpoint_nodes.append(last_node)

    @property
    def sub_logics(self) -> dict[str, Self]:
        """Instantiate every sub-logic recorded in ``self._sub_logics``.

        Each entry's ``logic_type`` is resolved through the metaclass registry.
        Constructor arguments are taken from the entry itself, or inferred for
        ``name``, ``parent`` and ``contexts``. Because construction goes through
        ``LogicGroupMeta``, an instance already cached in ``LGM`` under the same
        name is returned instead of a new one.

        Returns:
            dict[str, LogicGroup]: Sub-logic name -> instance.

        Raises:
            ValueError: If a recorded logic type is not in the registry.
            TypeError: If a required constructor argument cannot be inferred.
        """
        sub_logic_instances = {}
        for logic_name, info in self._sub_logics.items():
            logic_type = info["logic_type"]
            # Dynamically retrieve the class using meta registry
            logic_class = self.__class__.registry.get(logic_type)
            if logic_class is None:
                raise ValueError(f"Class {logic_type} not found in registry.")

            init_params = inspect.signature(logic_class.__init__).parameters
            init_args = {}
            for param_name, param in init_params.items():
                if param_name == "self":
                    continue
                if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                    # *args / **kwargs have no default but are never required.
                    continue
                if param_name in info:
                    init_args[param_name] = info[param_name]
                elif param_name == "name":
                    init_args["name"] = logic_name
                elif param_name == "parent":
                    init_args["parent"] = self
                elif param_name == "contexts":
                    LOGGER.warning(f"Contexts dict not found for {logic_name}!")
                    init_args["contexts"] = {}
                elif param.default is inspect.Parameter.empty:
                    # Missing a required argument that cannot be inferred
                    raise TypeError(f"Missing required argument '{param_name}' for {logic_type}.")

            sub_logic_instances[logic_name] = logic_class(**init_args)
        return sub_logic_instances
class SkipContextsBlock(object):
    """Context manager that can skip its ``with`` body entirely.

    Subclasses implement `_entry_check`. If it returns a truthy value, the body
    runs normally, bracketed by the `_on_enter` / `_on_exit` hooks. Otherwise
    `__enter__` installs a trace function on the caller's frame that raises
    `_Skip` at the next traced line. `__exit__` suppresses that `_Skip` and
    restores the previous tracer, so the body is skipped. In the skip path
    `__enter__` returns None and neither hook is called.
    """

    class _Skip(Exception):
        """Raised by `err_trace` to abort the ``with`` body; swallowed by `__exit__`."""
        pass

    def _entry_check(self) -> Any:
        """Return a truthy value to run the body, or a falsy value to skip it.

        The base implementation returns None, i.e. the body is always skipped.
        """
        pass

    @final
    def __enter__(self):
        if self._entry_check():  # Check if the expression evaluates to True
            self._on_enter()
            return self
        # Skip path: remember the current tracer so `__exit__` can restore it.
        self._original_trace = self.get_trace()
        # The frame that executes the `with` statement.
        frame = inspect.currentframe().f_back
        # Enable tracing; per-frame `f_trace` hooks only fire while a global tracer is set.
        sys.settrace(self.empty_trace)
        # The next traced line in the caller frame raises `_Skip`.
        frame.f_trace = self.err_trace

    @final
    def __exit__(self, exc_type, exc_value, exc_traceback):
        if exc_type is None:
            self._on_exit()
            return
        if issubclass(exc_type, self._Skip):
            if hasattr(self, '_original_trace'):
                sys.settrace(self._original_trace)  # Restore the original trace
            else:
                raise Exception('original_trace not found! Debugger broken! This should never happened.')
            return True
        self._on_exit()
        # Propagate any other exception raised in the block
        return False

    def _on_enter(self):
        """Hook run after a successful entry check; no-op by default."""
        pass

    def _on_exit(self):
        """Hook run when an entered body finishes (normally or with a non-skip error); no-op by default."""
        pass

    @staticmethod
    def get_trace():
        """
        Safely retrieve the current trace function, prioritizing the PyDev debugger's trace function.
        """
        try:
            # Check if PyDev debugger is active
            # noinspection PyUnresolvedReferences
            import pydevd
            debugger = pydevd.GetGlobalDebugger()
            if debugger is not None:
                return debugger.trace_dispatch  # Use PyDev's trace function
        except ImportError:
            pass  # PyDev debugger is not installed or active
        # Fall back to the standard trace function
        return sys.gettrace()

    @classmethod
    def empty_trace(cls, *args, **kwargs) -> None:
        """Global tracer that installs no local tracer for new frames."""
        pass

    @classmethod
    def err_trace(cls, frame, event, arg):
        """Local tracer that aborts the skipped ``with`` body on its first event."""
        raise cls._Skip("Expression evaluated to be False, cannot enter the block.")
class LogicExpression(SkipContextsBlock):
    """
    Represents a logical or mathematical expression that supports deferred evaluation.

    Operators (``&``, ``|``, ``==``, arithmetic and ordering comparisons)
    evaluate nothing when applied. They return a new LogicExpression whose
    evaluation combines the operands' evaluations. Used as a ``with`` context,
    the body runs only when the expression evaluates truthy (see SkipContextsBlock).
    """

    def __init__(
            self,
            expression: float | int | bool | Exception | Callable[[], Any],
            dtype: type = None,
            repr: str = None,
    ):
        """
        Initialize the LogicExpression.

        Args:
            expression (Union[Any, Callable[[], Any]]): A static value, a
                zero-argument callable, or an Exception raised on evaluation.
            dtype (type, optional): The expected type of the evaluated value (float, int, or bool).
            repr (str, optional): A string representation of the expression;
                defaults to ``str(expression)``.
        """
        self.expression = expression
        self.dtype = dtype
        self.repr = repr if repr is not None else str(expression)
        super().__init__()

    def _entry_check(self) -> Any:
        return self.eval()

    def eval(self, enforce_dtype: bool = False) -> Any:
        """Evaluate the expression.

        Args:
            enforce_dtype: If True (and a concrete dtype is set), convert the value
                with ``dtype(value)``; otherwise only log a warning on a type mismatch.

        Returns:
            Any: The static value, or the callable's result.

        Raises:
            Exception: The stored exception, when the expression is an Exception.
            TypeError: If the expression has an unsupported type.
        """
        if isinstance(self.expression, (float, int, bool, str)):
            value = self.expression
        elif callable(self.expression):
            value = self.expression()
        elif isinstance(self.expression, Exception):
            raise self.expression
        else:
            raise TypeError(f"Unsupported expression type: {type(self.expression)}.")

        if self.dtype is Any or self.dtype is None:
            pass  # No type enforcement
        elif enforce_dtype:
            value = self.dtype(value)
        elif not isinstance(value, self.dtype):
            LOGGER.warning(f"Evaluated value {value} does not match dtype {self.dtype.__name__}.")
        return value

    # Logical operators

    @classmethod
    def cast(cls, value: int | float | bool | Exception | Self, dtype: type = None) -> Self:
        """
        Convert a static value, callable, or error into a LogicExpression.

        Args:
            value (Union[int, float, bool, LogicExpression, Callable, Exception]):
                The value to convert. Can be:
                - A static value (int, float, or bool).
                - A callable returning a value.
                - A pre-existing LogicExpression (returned unchanged).
                - An Exception to raise during evaluation.
            dtype (type, optional): The expected type of the resulting LogicExpression.
                If None, it is ``type(value)`` for static values and ``Any`` otherwise.

        Returns:
            LogicExpression: The resulting LogicExpression.

        Raises:
            TypeError: If the value type is unsupported.
        """
        if isinstance(value, LogicExpression):
            return value
        if isinstance(value, (int, float, bool)):
            return LogicExpression(
                expression=value,
                dtype=dtype or type(value),
                repr=str(value)
            )
        if callable(value):
            return LogicExpression(
                expression=value,
                dtype=dtype or Any,
                repr=f"Eval({value})"
            )
        if isinstance(value, Exception):
            return LogicExpression(
                expression=value,
                dtype=dtype or Any,
                repr=f"Raises({type(value).__name__}: {value})"
            )
        raise TypeError(f"Unsupported type for LogicExpression conversion: {type(value)}.")

    def __bool__(self) -> bool:
        return bool(self.eval())

    def __and__(self, other: Self | bool) -> Self:
        other_expr = self.cast(value=other, dtype=bool)
        new_expr = LogicExpression(
            expression=lambda: self.eval() and other_expr.eval(),
            dtype=bool,
            repr=f"({self.repr} and {other_expr.repr})"
        )
        return new_expr

    def __eq__(self, other: int | float | bool | str | Self) -> Self:
        """Build a deferred equality test ``self == other``.

        Both operands are evaluated only when the returned expression is
        evaluated, consistent with the other operators. (Previously a
        LogicExpression operand was evaluated eagerly, here, at construction time.)
        A non-expression operand is compared as a constant.
        """
        if isinstance(other, LogicExpression):
            other_eval = other.eval
            other_repr = other.repr
        else:
            other_eval = lambda: other
            other_repr = repr(other)
        return LogicExpression(
            expression=lambda: self.eval() == other_eval(),
            dtype=bool,
            repr=f"({self.repr} == {other_repr})"
        )

    def __or__(self, other: Self | bool) -> Self:
        other_expr = self.cast(value=other, dtype=bool)
        new_expr = LogicExpression(
            expression=lambda: self.eval() or other_expr.eval(),
            dtype=bool,
            repr=f"({self.repr} or {other_expr.repr})"
        )
        return new_expr

    # Math operators

    @classmethod
    def _math_op(cls, self: Self, other: int | float | Self, op: Callable, operator_str: str, dtype: type = None) -> Self:
        """Build a deferred ``op(self, other)``; the dtype defaults to ``self.dtype``."""
        other_expr = LogicExpression.cast(other)
        if dtype is None:
            dtype = self.dtype
        new_expr = LogicExpression(
            expression=lambda: op(self.eval(), other_expr.eval()),
            dtype=dtype,
            repr=f"({self.repr} {operator_str} {other_expr.repr})",
        )
        return new_expr

    def __add__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.add, operator_str="+")

    def __sub__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.sub, operator_str="-")

    def __mul__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.mul, operator_str="*")

    def __truediv__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.truediv, operator_str="/")

    def __floordiv__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.floordiv, operator_str="//")

    def __pow__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.pow, operator_str="**")

    # Comparison operators, note that __eq__, __ne__ is special and should not implement as math operator

    def __lt__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.lt, operator_str="<", dtype=bool)

    def __le__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.le, operator_str="<=", dtype=bool)

    def __gt__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.gt, operator_str=">", dtype=bool)

    def __ge__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.ge, operator_str=">=", dtype=bool)

    def __repr__(self) -> str:
        return f"LogicExpression(dtype={'Any' if self.dtype is None else self.dtype.__name__}, repr={self.repr})"
class ExpressionCollection(LogicGroup):
    """A LogicGroup that keeps a piece of data in its contexts.

    Args:
        data: Stored as ``contexts['data']`` unless a value is already recorded there.
        name: Group name; when None, ``'<parent name>.<ClassName>'`` is used.
        **kwargs: ``logic_group`` selects the parent group explicitly (None for
            a root group). When omitted, the currently active group
            (``LGM.active_logic_group``) is used as parent.
    """

    def __init__(self, data: Any, name: str, **kwargs):
        # Honor an explicitly passed parent; otherwise fall back to the active group.
        # (The original condition was inverted and always ignored the argument.)
        if 'logic_group' in kwargs:
            logic_group = kwargs['logic_group']
        else:
            logic_group = LGM.active_logic_group
        super().__init__(
            name=name if name is not None else f'{logic_group.name}.{self.__class__.__name__}',
            parent=logic_group
        )
        # Keep previously recorded data (recovered from the parent) if any.
        self.data = self.contexts.setdefault('data', data)
[docs]
class LogicNode(LogicExpression):
[docs]
def __init__(
        self,
        expression: float | int | bool | Exception | Callable[[], Any],
        dtype: type = None,
        repr: str = None,
):
    """
    Create a decision-tree node wrapping `expression`.

    Args:
        expression (Union[Any, Callable[[], Any]]): A static value, a
            zero-argument callable, or an Exception raised on evaluation.
        dtype (type, optional): The expected type of the evaluated value.
        repr (str, optional): A string representation; defaults to ``str(expression)``.

    The node records the names of the currently active logic groups as its
    labels and starts with no parent, no children and no subordinates.
    """
    super().__init__(expression=expression, dtype=dtype, repr=repr)
    # Names of the LogicGroups active at creation time (outermost first).
    self.labels = [group.name for group in LGM._active_groups]
    # Children keyed by edge condition; `edges` keeps the conditions in insertion order.
    self.nodes: dict[Any, LogicNode] = {}
    self.edges = []
    self.parent: LogicNode | None = None
    # Nodes entered inside this node's `with` block (appended by LGM.enter_expression).
    self.subordinates = []
def _entry_check(self) -> Any:
    """
    Decide whether this node's ``with`` body should run.

    In inspection mode (``LGM.inspection_mode``) the body is always entered so
    the whole graph can be built. Otherwise the node's expression decides.

    Returns:
        Any: ``True`` in inspection mode, else the evaluation result.
    """
    return True if LGM.inspection_mode else self.eval()
def __rshift__(self, expression: Self):
    """Attach `expression` as a child via ``node >> child`` and return the child.

    Returning the child lets links chain: ``a >> b >> c`` links a->b, then b->c.
    """
    self.append(expression=expression)
    return expression
def __call__(self, default=None) -> Any:
    """
    Recursively evaluates the decision tree starting from this node.

    Inspection mode (which makes entry checks always pass) is switched off
    while the tree is walked. It is restored afterwards, even when evaluation raises.

    Keyword Args:
        default (Any, optional): Fallback passed to `eval_recursively` when no
            matching condition is found; a ``NoAction(auto_connect=False)`` is
            created when omitted.

    Returns:
        final_value (Any): The result of evaluating the last node of the decision path.

    Raises:
        TooFewChildren: If evaluation produced an empty decision path.
        Errors raised while evaluating nodes propagate unchanged.
    """
    if default is None:
        from .node import NoAction
        default = NoAction(auto_connect=False)

    ins_mode = LGM.inspection_mode
    if ins_mode:
        LOGGER.info('LGM inspection mode temporally disabled to evaluate correctly.')
        LGM.inspection_mode = False
    try:
        _, path = self.eval_recursively(default=default)
    finally:
        # Restore the caller's mode even when evaluation fails.
        LGM.inspection_mode = ins_mode

    if not path:
        raise TooFewChildren()
    leaf = path[-1]
    return leaf.eval()
def __repr__(self):
    """Return ``<ClassName>('expression repr')`` for debugging."""
    class_name = type(self).__name__
    return f'<{class_name}>({self.repr!r})'
def _on_enter(self):
    """Attach this node to the graph under construction when its ``with`` body is entered.

    Placement depends on the active expression node (``LGM.active_expression``)
    and on its most recent subordinate ``last_node``, i.e. the previous node
    entered inside the same ``with`` block:

    * no active expression: the node is registered directly via
      ``LGM.enter_expression`` (this happens for ActionNodes too);
    * no previous subordinate: the node becomes the active node's ``True`` child;
    * ``last_node`` has one child: the node is attached on the opposite boolean edge;
    * ``last_node`` has two children: the NoAction child on its last edge is
      replaced by this node (presumably the filler added by
      `fill_binary_branch` — TODO confirm);
    * ``last_node`` has zero or more than two children: an error is raised.

    Afterwards, non-ActionNodes are pushed onto the active expression stack.

    Raises:
        TooFewChildren: If ``last_node`` has no children.
        EdgeValueError: If ``last_node``'s only edge is not a bool.
        NodeValueError: If ``last_node``'s last child is not a NoAction.
        TooManyChildren: If ``last_node`` has more than two children.
    """
    active_node: LogicNode = LGM.active_expression
    if active_node is None:
        # Top-level node: nothing to attach to.
        return LGM.enter_expression(node=self)
    match active_node.subordinates:
        case []:
            # First node inside the active node's body -> its True branch.
            active_node.append(expression=self, edge_condition=True)
        case [*_, last_node] if not last_node.nodes:
            raise TooFewChildren()
        case [*_, last_node] if len(last_node.nodes) == 1:
            # Previous sibling has a single boolean branch: take the other one.
            edge_condition = last_node.last_edge
            if not isinstance(edge_condition, bool):
                raise EdgeValueError(f'{last_node} Edge condition must be a Boolean!')
            last_node.append(expression=self, edge_condition=not edge_condition)
        case [*_, last_node] if len(last_node.nodes) == 2:
            from .node import NoAction
            edge_condition, child = last_node.last_edge, last_node.last_node
            # Only a NoAction placeholder may be replaced.
            if not isinstance(child, NoAction):
                raise NodeValueError(f'{last_node} second child node must be a NoAction node!')
            last_node.pop(-1)
            last_node.append(expression=self, edge_condition=edge_condition)
        case [*_, last_node] if len(last_node.nodes) > 2:
            raise TooManyChildren()
    # ActionNodes are leaves: they are never pushed onto the active expression stack.
    if isinstance(self, ActionNode):
        pass
    else:
        LGM.enter_expression(node=self)
def _on_exit(self):
    """Close this node's ``with`` body.

    Completes the node's branches via `fill_binary_branch`, then pops it from
    LGM's active expression stack.
    """
    type(self).fill_binary_branch(node=self)
    LGM.exit_expression(node=self)
[docs]
@classmethod
def fill_binary_branch(cls, node: LogicNode, with_action: ActionNode = None):
    """
    Add a placeholder child so that `node` has a boolean branch pair.

    Args:
        node (LogicNode): The node to complete; ActionNodes are left untouched.
        with_action (ActionNode, optional): Child to attach. Defaults to a fresh
            ``NoAction(auto_connect=False)``.

    Behavior by current child count:
        * 0 -- a warning is logged and `with_action` is attached on the ``False`` edge.
        * 1 -- `with_action` is attached on the opposite of the existing boolean edge.
        * 2 or more -- ``TooManyChildren`` is raised.

    Raises:
        EdgeValueError: If the single existing edge is not a bool.
        TooManyChildren: If `node` already has two or more children.
    """
    if with_action is None:
        from .node import NoAction
        with_action = NoAction(auto_connect=False)
    if isinstance(node, ActionNode):
        return

    child_count = len(node.nodes)
    if child_count == 0:
        LOGGER.warning(f"It is rear that {node} having no True branch. Check the <with> statement code block to see if this is intended.")
        node.append(expression=with_action, edge_condition=False)
    elif child_count == 1:
        existing_edge = node.last_edge
        if not isinstance(existing_edge, bool):
            raise EdgeValueError(f'{node} Edge condition must be a Boolean!')
        node.append(expression=with_action, edge_condition=not existing_edge)
    else:
        raise TooManyChildren()
[docs]
@classmethod
def traverse(cls, node: Self, G=None, node_map: dict[int, Self] = None, parent: Self = None, edge_condition: Any = None):
    """
    Add `node` and everything reachable from it to a networkx DiGraph.

    Graph nodes are keyed by ``id(node)`` and carry a ``description`` attribute
    (the node's ``repr`` string). Edges carry a ``label`` attribute holding
    ``str(edge_condition)``.

    A node reachable through several parents (e.g. after breakpoint
    connections) gets an edge from every parent, but its subtree is expanded
    only once. This avoids redundant re-traversal and infinite recursion on cycles.

    Args:
        node (LogicNode): The current node being traversed.
        G (networkx.DiGraph, optional): The graph being constructed. Defaults to a new graph.
        node_map (dict, optional): Mapping of node ids to LogicNode instances; also
            serves as the set of already-expanded nodes.
        parent (LogicNode, optional): The parent node of the current node.
        edge_condition (Any, optional): The condition from parent to this node.

    Returns:
        tuple: ``(G, node_map)``.
    """
    import networkx as nx

    if G is None:
        G = nx.DiGraph()
    if node_map is None:
        node_map = {}

    node_id = id(node)
    first_visit = node_id not in node_map
    if first_visit:
        node_map[node_id] = node
        G.add_node(node_id, description=node.repr)

    if parent is not None:
        # Always record the edge, even for an already-expanded node.
        G.add_edge(id(parent), node_id, label=str(edge_condition))

    if first_visit:
        for condition, child in node.nodes.items():
            cls.traverse(node=child, G=G, node_map=node_map, parent=node, edge_condition=condition)
    return G, node_map
[docs]
def append(self, expression: LogicNode, edge_condition: Any = None):
    """
    Adds a child node to the current node.

    Args:
        expression (LogicNode): The child node; its ``parent`` is set to this node.
        edge_condition (Any, optional): The condition for branching. ``None``
            means the unconditional ``NO_CONDITION`` edge.

    Raises:
        ValueError: If this node already has a child under `edge_condition`.
    """
    if edge_condition is None:
        edge_condition = NO_CONDITION
    if edge_condition in self.nodes:
        raise ValueError(f"Edge {edge_condition} already exists.")
    self.edges.append(edge_condition)
    self.nodes[edge_condition] = expression
    expression.parent = self
[docs]
def pop(self, index: int = -1) -> tuple[Any, LogicNode]:
    """Detach the child at position `index` of the edge order.

    The removed child's ``parent`` attribute is left as is.

    Returns:
        tuple: ``(edge_condition, child_node)``.

    Raises:
        IndexError: If `index` is out of range.
    """
    removed_edge = self.edges.pop(index)
    return removed_edge, self.nodes.pop(removed_edge)
[docs]
def replace(self, original_node: LogicNode, new_node: LogicNode):
    """Swap the child `original_node` (matched by identity) for `new_node` on the same edge.

    Only the ``self.nodes`` mapping changes. The edge order is untouched and
    neither node's ``parent`` attribute is updated.

    Raises:
        NodeNotFountError: If `original_node` is not a child of this node.
    """
    missing = object()
    condition = next(
        (cond for cond, child in self.nodes.items() if child is original_node),
        missing,
    )
    if condition is missing:
        raise NodeNotFountError()
    self.nodes[condition] = new_node
[docs]
def eval_recursively(self, **kwargs):
    """
    Recursively evaluates the decision tree starting from this node.

    Every node on the way is evaluated and appended to the decision path. The
    first child whose edge condition equals the node's value (or whose edge is
    ``NO_CONDITION``) is followed.

    Keyword Args:
        path (list, optional): Decision path accumulated so far; this node is
            appended to it. A new list is started when omitted or None.
        default (Any, optional): Fallback value if no matching condition is found
            at this node or at any node below it.

    Returns:
        tuple: (final_value, decision_path)
            - final_value (Any): The evaluated result of the tree.
            - decision_path (list): The nodes traversed, from the starting node
              to the node where evaluation stopped.

    Raises:
        ValueError: If no matching condition is found and no default value is provided.
    """
    path = kwargs.get('path')
    if path is None:
        path = []
    # Record every visited node, not just the starting one.
    path.append(self)

    value = self.eval()
    if not self.nodes:
        return value, path

    for condition, child in self.nodes.items():
        if condition == value or condition is NO_CONDITION:
            # Forward `default` so a dead end deeper in the tree falls back to it too.
            child_kwargs = {'path': path}
            if 'default' in kwargs:
                child_kwargs['default'] = kwargs['default']
            return child.eval_recursively(**child_kwargs)

    if 'default' in kwargs:
        default = kwargs['default']
        LOGGER.info(f"No matching condition found for value {value} at '{self.repr}', using default {default}.")
        return default, path
    raise ValueError(f"No matching condition found for value {value} at '{self.repr}'.")
[docs]
def list_labels(self) -> dict[str, list[LogicNode]]:
    """
    Group the nodes of this subtree by logic-group label.

    Walks the tree depth-first in pre-order (following ``node.nodes``) and maps
    each label to the nodes carrying it, in visiting order. A node reachable
    through several parents is listed once per visit.
    """
    labels: dict[str, list[LogicNode]] = {}

    def visit(current):
        for label in current.labels:
            labels.setdefault(label, []).append(current)
        for child in current.nodes.values():
            visit(child)

    visit(self)
    return labels
[docs]
def select_node(self, label: str) -> LogicNode | None:
    """
    Return the root node of logic group `label` within this subtree.

    A node counts as a root when it carries `label` and does not appear in the
    node list of any other label (see `list_labels`).

    Returns:
        LogicNode | None: The root node, or None when `label` is absent or no
        node qualifies.

    Raises:
        ValueError: If more than one distinct node qualifies as root.
    """
    labels = self.list_labels()
    if label not in labels:
        return None

    # Compare nodes by identity. `node in some_list` would call
    # LogicExpression.__eq__, which builds an expression whose truthiness
    # evaluates both nodes. That can match unrelated nodes with equal values.
    ids_in_other_labels = {
        id(node)
        for other_label, nodes in labels.items()
        if other_label != label
        for node in nodes
    }

    root = None
    for node in labels[label]:
        if id(node) in ids_in_other_labels or node is root:
            # Belongs to another group too, or was already selected (reached twice).
            continue
        if root is not None:
            raise ValueError(f"Logic group '{label}' has multiple roots.")
        root = node
    return root
[docs]
def to_html(self, with_group=True, dry_run=True, filename="decision_graph.html", **kwargs):
    """
    Visualizes the decision tree using PyVis and writes it to an HTML file.

    If dry_run=True, shows structure without highlighting active path.
    If dry_run=False, evaluates the tree and highlights the decision path. If
    that evaluation fails, the error is logged and a dry-run view is produced.
    If with_group=True, uses grouped logic view: nodes carry their group labels
    and one highlight button is added per group.

    Args:
        with_group (bool): Enable the grouped view and group buttons.
        dry_run (bool): Skip evaluation / path highlighting.
        filename (str): Output path; the file is written as UTF-8.
        **kwargs: Optional overrides: ``height`` ("750px"), ``width`` ("100%"),
            ``default_color`` ("lightblue"), ``highlight_color`` ("lightgreen",
            group highlight), ``selected_color`` ("lightyellow", nodes on the
            evaluated path), ``dimmed_color`` ("#e0e0e0"), ``logic_shape``
            ("box"), ``action_shape`` ("ellipse").
    """
    # Imported under an alias because `html` is used below as a local variable.
    from html import escape as html_escape
    from pyvis.network import Network

    G, node_map = self.traverse(self)

    # Highlight path if not in dry run
    activated_path = set()
    if not dry_run:
        try:
            _, path = self.eval_recursively()
            activated_path = {id(node) for node in path}
        except Exception:
            activated_path.clear()
            dry_run = True
            LOGGER.error(f"Failed to evaluate decision tree.\n{traceback.format_exc()}")

    # Visualization using PyVis
    net = Network(
        height=kwargs.get('height', "750px"),
        width=kwargs.get('width', "100%"),
        directed=True,
        notebook=False,
        neighborhood_highlight=True
    )

    default_color = kwargs.get('default_color', "lightblue")
    highlight_color = kwargs.get('highlight_color', "lightgreen")
    activated_color = kwargs.get('selected_color', "lightyellow")
    dimmed_color = kwargs.get('dimmed_color', "#e0e0e0")
    logic_shape = kwargs.get('logic_shape', "box")
    action_shape = kwargs.get('action_shape', "ellipse")
    original_colors = {}

    # Add nodes with group information
    for node_id, node in node_map.items():
        label = node.repr
        title = f"Node: {node.repr}"
        # Track the original color for each node
        node_color = activated_color if node_id in activated_path else default_color
        original_colors[node_id] = node_color
        shape = action_shape if isinstance(node, ActionNode) else logic_shape
        if with_group:
            net.add_node(node_id, label=label, title=title, color=node_color, shape=shape, groups=node.labels)
        else:
            net.add_node(node_id, label=label, title=title, color=node_color, shape=shape)

    # Add edges
    for source, target, data in G.edges(data=True):
        edge_label = data.get("label", "")
        edge_color = "black" if dry_run else ("green" if source in activated_path and target in activated_path else "black")
        net.add_edge(source, target, label=edge_label, title=edge_label, color=edge_color, arrows="to")

    # Configure layout and options
    options = {
        "layout": {
            "hierarchical": {
                "enabled": True,
                "direction": "UD",  # UD = Up-Down (root at top, leaves at bottom)
                "sortMethod": "directed",
                "nodeSpacing": 150,
                "levelSeparation": 200
            }
        },
        "physics": {
            "hierarchicalRepulsion": {
                "centralGravity": 0.0,
                "springLength": 200,
                "springConstant": 0.01,
                "nodeDistance": 200,
                "damping": 0.09
            },
            "minVelocity": 0.75,
            "solver": "hierarchicalRepulsion"
        },
        "nodes": {
            "shape": "box",
            "shapeProperties": {"borderRadius": 10},
            "font": {"size": 14}
        },
        "edges": {
            "color": "black",
            "smooth": True
        }
    }
    net.set_options(json.dumps(options))

    # Generate the base HTML
    html = net.generate_html()

    # Inject custom controls and JavaScript
    buttons_html = """
<div style="position: absolute; top: 10px; left: 10px; z-index: 1000;
background: rgba(255, 255, 255, 0.9); padding: 12px;
border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.3);
font-family: Arial, sans-serif;">
<h4 style="margin: 0 0 10px; font-size: 16px; text-align: center; color: #333;">
Decision Tree Controls
</h4>
<button onclick="resetColors()" class="control-btn">Reset</button>
"""
    if with_group:
        groups = {group for node in node_map.values() for group in node.labels}
        for group in sorted(groups):
            # Encode as a JS string literal, then escape for the HTML attribute, so
            # quotes or markup in a group name cannot break the page.
            js_arg = html_escape(json.dumps(group))
            buttons_html += f'<button onclick="highlightGroup({js_arg})" class="control-btn">{html_escape(group)}</button>'
    buttons_html += "</div>"

    js_code = f"""
<script>
function resetColors() {{
// Reset all nodes to their original color and opacity
nodes.forEach(function(node) {{
nodes.update([{{
id: node.id,
color: originalColors[node.id], // Reset to original color
opacity: 1
}}]);
}});
// Reset all edges to default color and opacity
edges.forEach(function(edge) {{
edges.update([{{
id: edge.id,
color: "black",
opacity: 1
}}]);
}});
}}
function highlightGroup(group) {{
// Dim all nodes and edges first
nodes.update([...nodes.getIds().map(id => ({{
id: id,
color: "{dimmed_color}",
opacity: 0.3
}}))]);
edges.update([...edges.getIds().map(id => ({{
id: id,
color: "gray",
opacity: 0.2
}}))]);
// Highlight nodes in the selected group
const groupNodes = nodes.get({{
filter: node => node.groups.includes(group)
}});
nodes.update([...groupNodes.map(node => ({{
id: node.id,
color: "{highlight_color}",
opacity: 1
}}))]);
// Highlight connected edges
const connectedEdges = edges.get({{
filter: edge =>
groupNodes.some(n => n.id === edge.from) ||
groupNodes.some(n => n.id === edge.to)
}});
edges.update([...connectedEdges.map(edge => ({{
id: edge.id,
color: "black",
opacity: 1
}}))]);
}}
// Store the original node colors for reset functionality
const originalColors = {json.dumps(original_colors)};
</script>
"""

    # Inject better styles for buttons
    css_styles = """
<style>
.control-btn {
background-color: #007BFF;
color: white;
border: none;
padding: 8px 14px;
margin: 5px;
font-size: 14px;
border-radius: 5px;
cursor: pointer;
transition: background 0.3s ease;
}
.control-btn:hover {
background-color: #0056b3;
}
.control-btn:active {
background-color: #003f7f;
}
</style>
"""

    # Insert custom elements into the HTML
    html = html.replace("</head>", f"{css_styles}</head>")
    html = html.replace("</body>", f"{buttons_html}{js_code}</body>")

    # Save the modified HTML with an explicit encoding (node labels may be non-ASCII).
    with open(filename, "w", encoding="utf-8") as f:
        f.write(html)
    LOGGER.info(f"Decision tree saved to {filename}")
@property
def children(self) -> Iterable[tuple[Any, LogicNode]]:
    """Iterate over this node's outgoing branches.

    Returns:
        An iterator of ``(edge, node)`` tuples drawn from ``self.nodes``,
        in that mapping's iteration order.
    """
    branch_items = self.nodes.items()
    return iter(branch_items)
@property
def leaves(self) -> Iterable[LogicNode]:
    """Yield every leaf reachable from this node, depth-first.

    A leaf is a node whose ``nodes`` mapping is empty, so a childless node
    yields only itself. Children are visited in the iteration order of
    ``self.nodes``, and each child's own ``leaves`` property is used to
    descend further.
    """
    if self.nodes:
        for subtree in self.nodes.values():
            yield from subtree.leaves
        return
    yield self
@property
def last_edge(self) -> Any:
    """Return the final entry of ``self.edges``.

    Raises:
        IndexError: if ``self.edges`` is empty.
    """
    edge_list = self.edges
    return edge_list[-1]
@property
def last_node(self) -> LogicNode:
    """Return the child stored under ``self.last_edge``.

    ``self.last_edge`` is the final entry of ``self.edges``; the child is
    looked up in ``self.nodes`` with that key.
    """
    final_edge = self.last_edge
    return self.nodes[final_edge]
@property
def last_leaf(self) -> LogicNode:
    """Follow the last branch at every level and return the node reached.

    A node without children is its own last leaf; otherwise the search
    continues through ``self.last_node``'s own ``last_leaf``.
    """
    return self.last_node.last_leaf if self.nodes else self
@property
def last_leaf_expression(self) -> LogicNode:
    """Return the last leaf, or its ``parent`` when that leaf is an ActionNode.

    This lets callers reach the node that owns a trailing action rather
    than the action node itself.
    """
    leaf = self.last_leaf
    return leaf.parent if isinstance(leaf, ActionNode) else leaf
[docs]
class ActionNode(LogicNode):
[docs]
def __init__(
        self,
        action: Callable[[], Any] | None = None,
        repr: str | None = None,
        auto_connect: bool = True
):
    """
    Initialize the ActionNode.

    Args:
        action (Callable[[], Any] | None): Zero-argument callable that
            ``_post_eval`` invokes after this node is evaluated; ``None``
            means no action is run.
        repr (str | None, optional): A string representation of the node,
            forwarded to ``LogicNode.__init__``.
        auto_connect (bool): If True, auto-connect to the current active
            decision graph.
    """
    # NOTE(review): ``expression=True`` is handed to LogicNode; presumably
    # this makes the node's own evaluation yield True — confirm in LogicNode.
    super().__init__(expression=True, repr=repr)
    self.action = action
    if auto_connect:
        # Deliberately call the base-class hook: ActionNode's own _on_enter
        # override only logs a warning, so self._on_enter() would not connect.
        super()._on_enter()
def _on_enter(self):
    """Entry-hook override that only logs a warning.

    This override does nothing beyond logging; ``__init__`` calls the
    base-class ``_on_enter`` directly when ``auto_connect`` is True.
    """
    # NOTE(review): presumably reached when the node is used as a context
    # manager (``with ActionNode(...)``) — confirm against LogicNode.
    # The message previously read "should not use with claude" (typo for "clause").
    LOGGER.warning(f'{self.__class__.__name__} should not be used in a with clause')
def _on_exit(self):
    """Exit-hook override; intentionally does nothing for an ActionNode."""
def _post_eval(self):
    """
    Run the attached action, if one was provided, after evaluation.

    Subclasses may override this method to perform clean-up work.
    """
    callback = self.action
    if callback is None:
        return
    callback()
[docs]
def eval_recursively(self, path=None):
    """
    Evaluate this action node and record it on the decision path.

    The node's value comes from ``self.eval()``; ``_post_eval`` then runs the
    attached action. An ActionNode is expected to have no children
    (``append`` refuses them). If it has some anyway, one warning is logged
    and evaluation continues into the first child whose edge condition is
    the ELSE condition or equals the value.

    Args:
        path (list | None): Nodes visited so far. A new list is created when
            None; otherwise it is appended to in place.

    Returns:
        tuple: ``(value, path)`` for this node, or whatever the matching
        child's ``eval_recursively`` returns.
    """
    if path is None:
        path = []
    path.append(self)
    value = self.eval()
    self._post_eval()
    if self.nodes:
        # Warn once rather than once per inspected child.
        LOGGER.warning(f'{self.__class__.__name__} should not have any sub-nodes.')
    for condition, child in self.nodes.items():
        # Identity test first, so the ELSE branch never triggers a (possibly
        # reflected, possibly ambiguous) ``__eq__`` on ``value``.
        if condition is NO_CONDITION or condition == value:
            return child.eval_recursively(path=path)
    return value, path
[docs]
def append(self, expression: Self, edge_condition: Any = None):
    """Refuse to add a child: an ActionNode is always a leaf.

    Raises:
        TooManyChildren: unconditionally.
    """
    message = "Cannot append child to an ActionNode!"
    raise TooManyChildren(message)