# Source code for CodeEntropy.levels.level_dag

"""Hierarchy-level DAG orchestration and reduction.

This module defines the `LevelDAG`, which coordinates two stages of the hierarchy
workflow:

1) Static stage (runs once):
   - Detect molecules and available resolution levels.
   - Build beads for each (molecule, level) definition.
   - Initialise accumulators used during per-frame reduction.
   - Compute conformational state descriptors required later by entropy nodes.
   - Compute neighbor information (`ComputeNeighborsNode`).

2) Frame stage (runs for each trajectory frame):
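   - Execute the `FrameGraph` to produce frame-local covariance outputs, either
     serially or on Dask workers when a client is supplied and parallel frame
     execution is enabled.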
   - Reduce frame-local outputs into running (incremental) means.
"""

from __future__ import annotations

import inspect
import logging
from typing import Any

import networkx as nx
from rich.progress import TaskID

from CodeEntropy.levels.axes import AxesCalculator
from CodeEntropy.levels.frame_dag import FrameGraph
from CodeEntropy.levels.nodes.accumulators import InitCovarianceAccumulatorsNode
from CodeEntropy.levels.nodes.beads import BuildBeadsNode
from CodeEntropy.levels.nodes.conformations import ComputeConformationalStatesNode
from CodeEntropy.levels.nodes.detect_levels import DetectLevelsNode
from CodeEntropy.levels.nodes.detect_molecules import DetectMoleculesNode
from CodeEntropy.levels.nodes.find_neighbors import ComputeNeighborsNode
from CodeEntropy.results.reporter import _RichProgressSink

logger = logging.getLogger(__name__)


_FRAME_WORKER_EXCLUDED_SHARED_KEYS = {
    "force_covariances",
    "torque_covariances",
    "forcetorque_covariances",
    "frame_counts",
    "forcetorque_counts",
    "force_torque_stats",
    "force_torque_counts",
    "n_frames",
    "entropy_manager",
    "run_manager",
    "reporter",
    "dask_client",
}


def _execute_frame_worker(
    shared_data: dict[str, Any],
    frame_index: int,
    universe_operations: Any | None = None,
) -> tuple[int, Any]:
    """Run the frame-local DAG for a single frame inside a Dask worker.

    A new ``FrameGraph`` instance is constructed and built on every call.

    Args:
        shared_data: Worker-local shared calculation inputs.
        frame_index: Index of the frame to evaluate; coerced to ``int``.
        universe_operations: Optional universe operations adapter forwarded
            to the ``FrameGraph`` constructor.

    Returns:
        A ``(frame_index, frame_output)`` pair, where ``frame_output`` is the
        frame-local covariance payload returned by ``execute_frame``.
    """
    graph = FrameGraph(universe_operations=universe_operations)
    graph = graph.build()
    index = int(frame_index)
    frame_output = graph.execute_frame(shared_data, index)
    return index, frame_output


class LevelDAG:
    """Execute hierarchy detection, per-frame covariance calculation, and reduction.

    The LevelDAG is responsible for:
      - Running a static DAG (once) to prepare shared inputs.
      - Running a per-frame DAG (for each frame) to compute frame-local outputs.
      - Reducing frame-local outputs into shared running means.

    The reduction performed here is an incremental mean across frames (and across
    molecules within a group when frame nodes average within-frame first).
    """

    def __init__(self, universe_operations: Any | None = None) -> None:
        """Initialise a LevelDAG.

        Args:
            universe_operations: Optional adapter providing universe operations.
                Passed to the FrameGraph and the conformational-state node.
        """
        self._universe_operations = universe_operations
        self._static_graph = nx.DiGraph()
        self._static_nodes: dict[str, Any] = {}
        self._frame_dag = FrameGraph(universe_operations=universe_operations)

    def build(self) -> LevelDAG:
        """Build the static and frame DAG topology.

        This registers all static nodes and their dependencies, and builds the
        internal FrameGraph used for per-frame execution.

        Returns:
            Self, to allow fluent chaining.
        """
        self._add_static("detect_molecules", DetectMoleculesNode())
        self._add_static("detect_levels", DetectLevelsNode(), deps=["detect_molecules"])
        self._add_static("build_beads", BuildBeadsNode(), deps=["detect_levels"])
        self._add_static(
            "init_covariance_accumulators",
            InitCovarianceAccumulatorsNode(),
            deps=["detect_levels"],
        )
        self._add_static(
            "compute_conformational_states",
            ComputeConformationalStatesNode(self._universe_operations),
            deps=["detect_levels"],
        )
        self._add_static(
            "find_neighbors", ComputeNeighborsNode(), deps=["detect_levels"]
        )
        self._frame_dag.build()
        return self

    def execute(
        self, shared_data: dict[str, Any], *, progress: _RichProgressSink | None = None
    ) -> dict[str, Any]:
        """Execute the full hierarchy workflow and mutate shared_data.

        This method ensures required shared components exist, runs the static
        stage once, then iterates through trajectory frames to run the per-frame
        stage and reduce outputs into running means.

        Args:
            shared_data: Shared workflow data dict. This mapping is mutated
                in-place by both static and frame stages.
            progress: Optional progress sink passed through to nodes and used for
                per-frame progress reporting when supported.

        Returns:
            The same shared_data mapping passed in, after mutation.
        """
        shared_data.setdefault("axes_manager", AxesCalculator())
        self._run_static_stage(shared_data, progress=progress)
        self._run_frame_stage(shared_data, progress=progress)
        return shared_data

    def _run_static_stage(
        self, shared_data: dict[str, Any], *, progress: _RichProgressSink | None = None
    ) -> None:
        """Run all static nodes in dependency order.

        Nodes are executed in topological order of the static DAG. When a
        progress sink is provided it is forwarded as ``progress=`` to nodes
        whose ``run`` signature accepts it (a ``progress`` parameter or
        ``**kwargs``). Support is decided from the signature rather than by
        catching ``TypeError``, so a genuine ``TypeError`` raised inside a node
        propagates instead of silently re-running the node. Only when the
        signature cannot be inspected is the legacy probe used (try with
        ``progress``, retry without it on ``TypeError``).

        Args:
            shared_data: Shared workflow data dict to be mutated by static nodes.
            progress: Optional progress sink to pass to nodes that support it.
        """
        for node_name in nx.topological_sort(self._static_graph):
            node = self._static_nodes[node_name]

            if progress is None:
                node.run(shared_data)
                continue

            accepts_progress = self._accepts_progress(node.run)
            if accepts_progress is None:
                # Signature unavailable: fall back to the probe-by-call behaviour.
                try:
                    node.run(shared_data, progress=progress)
                except TypeError:
                    node.run(shared_data)
            elif accepts_progress:
                node.run(shared_data, progress=progress)
            else:
                node.run(shared_data)

    @staticmethod
    def _accepts_progress(run: Any) -> bool | None:
        """Report whether a callable accepts a ``progress`` keyword argument.

        Args:
            run: The callable to inspect (typically a bound ``node.run``).

        Returns:
            True if ``run`` declares a ``progress`` parameter that can be passed
            by keyword, or accepts ``**kwargs``; False otherwise; None if the
            signature cannot be inspected.
        """
        try:
            signature = inspect.signature(run)
        except (TypeError, ValueError):
            return None

        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        for param in signature.parameters.values():
            if param.kind is inspect.Parameter.VAR_KEYWORD:
                return True
            if param.name == "progress" and param.kind in keyword_kinds:
                return True
        return False

    def _add_static(self, name: str, node: Any, deps: list[str] | None = None) -> None:
        """Register a static node and its dependencies in the static DAG.

        Args:
            name: Unique node name used in the static DAG.
            node: Node object exposing a run(shared_data, **kwargs) method.
            deps: Optional list of upstream node names that must run before this
                node.

        Returns:
            None. Mutates the internal static graph and node registry.
        """
        self._static_nodes[name] = node
        self._static_graph.add_node(name)
        for dep in deps or []:
            self._static_graph.add_edge(dep, name)

    def _run_frame_stage(
        self,
        shared_data: dict[str, Any],
        *,
        progress: _RichProgressSink | None = None,
    ) -> None:
        """Execute the per-frame DAG stage and reduce frame outputs.

        This method iterates over explicit frame indices provided by
        ``shared_data["frame_source"]``. During this migration stage, those
        indices are local indices into the physically frame-reduced analysis
        universe. After physical frame slicing is removed, they will be
        absolute source-trajectory indices.

        FrameGraph owns trajectory positioning. LevelDAG only chooses which
        frame indices to process and reduces each frame-local output into
        shared accumulators.

        If ``shared_data["dask_client"]`` exists and parallel frame execution
        is enabled, frame-local outputs are computed on Dask workers and
        reduced in the parent process.

        Args:
            shared_data: Shared data dictionary. Must contain ``frame_source``.
            progress: Optional progress sink.

        Returns:
            None. Mutates ``shared_data`` in-place via reduction.
        """
        frame_source = shared_data["frame_source"]
        frame_indices = [
            int(frame_index) for frame_index in frame_source.iter_indices()
        ]
        shared_data["n_frames"] = len(frame_indices)

        task: TaskID | None = None
        if progress is not None:
            task = progress.add_task(
                "[green]Frame processing",
                total=len(frame_indices),
                title="Initializing",
            )

        client = shared_data.get("dask_client")
        # Parallel execution defaults to "on" whenever a client is present.
        parallel_frames = bool(shared_data.get("parallel_frames", client is not None))
        if parallel_frames and client is not None and len(frame_indices) > 1:
            self._run_frame_stage_dask(
                shared_data,
                frame_indices=frame_indices,
                client=client,
                progress=progress,
                task=task,
            )
            return

        for frame_index in frame_indices:
            if progress is not None and task is not None:
                progress.update(task, title=f"Frame {frame_index}")

            frame_out = self._frame_dag.execute_frame(
                shared_data,
                frame_index,
            )
            self._reduce_one_frame(shared_data, frame_out)

            if progress is not None and task is not None:
                progress.advance(task)

    @staticmethod
    def _make_frame_worker_shared_data(shared_data: dict[str, Any]) -> dict[str, Any]:
        """Return the subset of shared data required by frame workers.

        Reduction accumulators and parent orchestration/reporting objects are
        intentionally excluded because workers should only compute frame-local
        outputs.
        """
        return {
            key: value
            for key, value in shared_data.items()
            if key not in _FRAME_WORKER_EXCLUDED_SHARED_KEYS
        }

    def _run_frame_stage_dask(
        self,
        shared_data: dict[str, Any],
        *,
        frame_indices: list[int],
        client: Any,
        progress: _RichProgressSink | None = None,
        task: TaskID | None = None,
    ) -> None:
        """Execute frame-local DAG tasks in parallel using Dask.

        Workers return frame-local covariance payloads. The parent process
        performs all reductions into the shared accumulators. Outputs are
        reduced in the order their futures complete (``as_completed``), so the
        floating-point summation order may differ from the serial path.

        Important:
            Do not scatter/broadcast worker_shared. It contains stateful objects
            such as frame_source / universe trajectory state. Broadcasting can
            reuse mutable state across tasks on the same worker and make frames
            interfere with one another.

        Raises:
            RuntimeError: If ``dask.distributed`` is not installed, or if fewer
                frames than submitted were reduced. Any exception raised while
                collecting results cancels the outstanding futures and is
                re-raised.
        """
        try:
            from distributed import as_completed
        except ImportError as exc:
            raise RuntimeError(
                "Parallel frame execution requires dask.distributed to be installed."
            ) from exc

        worker_shared = self._make_frame_worker_shared_data(shared_data)

        futures = [
            client.submit(
                _execute_frame_worker,
                worker_shared,
                frame_index,
                self._universe_operations,
                pure=False,
            )
            for frame_index in frame_indices
        ]

        completed = 0
        try:
            for future in as_completed(futures):
                frame_index, frame_out = future.result()
                completed += 1

                if progress is not None and task is not None:
                    progress.update(task, title=f"Frame {frame_index}")

                self._reduce_one_frame(shared_data, frame_out)

                if progress is not None and task is not None:
                    progress.advance(task)

            if completed != len(frame_indices):
                raise RuntimeError(
                    f"Parallel frame execution completed {completed} frames, "
                    f"but expected {len(frame_indices)}."
                )
        except Exception:
            client.cancel(futures)
            raise

    @staticmethod
    def _incremental_mean(old: Any, new: Any, n: int) -> Any:
        """Compute an incremental mean.

        Args:
            old: Previous running mean (or None for first sample).
            new: New sample to incorporate.
            n: 1-based sample count after adding `new`.

        Returns:
            Updated running mean.
        """
        if old is None:
            # Copy array-like samples so the accumulator never aliases frame data.
            return new.copy() if hasattr(new, "copy") else new
        return old + (new - old) / float(n)

    def _reduce_one_frame(
        self, shared_data: dict[str, Any], frame_out: dict[str, Any]
    ) -> None:
        """Reduce one frame's covariance outputs into shared running means.

        Args:
            shared_data: Shared workflow data dict containing accumulators.
            frame_out: Frame-local covariance outputs produced by FrameGraph.
        """
        self._reduce_force_and_torque(shared_data, frame_out)
        self._reduce_forcetorque(shared_data, frame_out)

    def _reduce_force_and_torque(
        self, shared_data: dict[str, Any], frame_out: dict[str, Any]
    ) -> None:
        """Reduce force/torque covariance outputs into shared accumulators.

        Force and torque means for a slot share a single sample counter in
        ``frame_counts``. Within one frame the counter is advanced by the force
        loop, and by the torque loop only for slots that had no force entry in
        this same frame, so every frame contributing to a slot is counted
        exactly once.

        Args:
            shared_data: Shared workflow data dict containing:
                - "force_covariances", "torque_covariances": accumulator structures.
                - "frame_counts": running sample counts for each accumulator slot.
                - "group_id_to_index": mapping from group id to accumulator index.
            frame_out: Frame-local outputs containing "force" and "torque" sections.

        Returns:
            None. Mutates accumulator values and counts in shared_data in-place.
        """
        f_cov = shared_data["force_covariances"]
        t_cov = shared_data["torque_covariances"]
        counts = shared_data["frame_counts"]
        gid2i = shared_data["group_id_to_index"]

        f_frame = frame_out["force"]
        t_frame = frame_out["torque"]

        # United-atom level: keyed directly by bead key, counters default to 0.
        ua_forces = f_frame["ua"]
        for key, f_mat in ua_forces.items():
            counts["ua"][key] = counts["ua"].get(key, 0) + 1
            n = counts["ua"][key]
            f_cov["ua"][key] = self._incremental_mean(f_cov["ua"].get(key), f_mat, n)

        for key, t_mat in t_frame["ua"].items():
            # Count the frame here only if the force loop did not already.
            if key not in ua_forces:
                counts["ua"][key] = counts["ua"].get(key, 0) + 1
            n = counts["ua"][key]
            t_cov["ua"][key] = self._incremental_mean(t_cov["ua"].get(key), t_mat, n)

        # Residue and polymer levels: keyed by group id via group_id_to_index.
        for level in ("res", "poly"):
            self._reduce_group_force_and_torque(
                level, f_cov, t_cov, counts, gid2i, f_frame, t_frame
            )

    def _reduce_group_force_and_torque(
        self,
        level: str,
        f_cov: dict[str, Any],
        t_cov: dict[str, Any],
        counts: dict[str, Any],
        gid2i: dict[Any, Any],
        f_frame: dict[str, Any],
        t_frame: dict[str, Any],
    ) -> None:
        """Reduce one group-indexed level (``"res"`` or ``"poly"``).

        Args:
            level: Level key shared by the accumulators and frame sections.
            f_cov: Force covariance accumulators, indexed ``[level][index]``.
            t_cov: Torque covariance accumulators, indexed ``[level][index]``.
            counts: Shared sample counters, indexed ``[level][index]``.
            gid2i: Mapping from group id to accumulator index.
            f_frame: This frame's force section, keyed ``[level][group_id]``.
            t_frame: This frame's torque section, keyed ``[level][group_id]``.
        """
        forces = f_frame[level]
        for gid, f_mat in forces.items():
            gi = gid2i[gid]
            counts[level][gi] += 1
            n = counts[level][gi]
            f_cov[level][gi] = self._incremental_mean(f_cov[level][gi], f_mat, n)

        for gid, t_mat in t_frame[level].items():
            gi = gid2i[gid]
            # Count the frame here only if the force loop did not already.
            if gid not in forces:
                counts[level][gi] += 1
            n = counts[level][gi]
            t_cov[level][gi] = self._incremental_mean(t_cov[level][gi], t_mat, n)

    def _reduce_forcetorque(
        self, shared_data: dict[str, Any], frame_out: dict[str, Any]
    ) -> None:
        """Reduce combined force-torque covariance outputs into shared accumulators.

        Args:
            shared_data: Shared workflow data dict containing:
                - "forcetorque_covariances": accumulator structures.
                - "forcetorque_counts": running sample counts for each accumulator slot.
                - "group_id_to_index": mapping from group id to accumulator index.
            frame_out: Frame-local outputs that may include a "forcetorque" section.

        Returns:
            None. Mutates accumulator values and counts in shared_data in-place.
        """
        if "forcetorque" not in frame_out:
            return

        ft_cov = shared_data["forcetorque_covariances"]
        ft_counts = shared_data["forcetorque_counts"]
        gid2i = shared_data["group_id_to_index"]
        ft_frame = frame_out["forcetorque"]

        for gid, m_mat in ft_frame.get("res", {}).items():
            gi = gid2i[gid]
            ft_counts["res"][gi] += 1
            n = ft_counts["res"][gi]
            ft_cov["res"][gi] = self._incremental_mean(ft_cov["res"][gi], m_mat, n)

        for gid, m_mat in ft_frame.get("poly", {}).items():
            gi = gid2i[gid]
            ft_counts["poly"][gi] += 1
            n = ft_counts["poly"][gi]
            ft_cov["poly"][gi] = self._incremental_mean(ft_cov["poly"][gi], m_mat, n)