Source code for CodeEntropy.entropy.workflow

"""Entropy manager orchestration.

This module defines `EntropyWorkflow`, which coordinates the end-to-end entropy
workflow:
- Determine trajectory bounds and frame count.
- Build a reduced universe based on atom selection.
- Identify molecule groups and hierarchy levels.
- Optionally compute water entropy and adjust selection.
- Execute the level DAG (matrix/state preparation).
- Execute the entropy graph (entropy calculations and aggregation).
- Finalize and persist results.

The manager intentionally delegates calculations to dedicated components.
"""

from __future__ import annotations

import logging
from collections import defaultdict
from collections.abc import Mapping
from typing import Any

import pandas as pd

from CodeEntropy.core.dask_clusters import HPCDaskManager
from CodeEntropy.core.logging import LoggingConfig
from CodeEntropy.entropy.graph import EntropyGraph
from CodeEntropy.entropy.water import WaterEntropy
from CodeEntropy.levels.hierarchy import HierarchyBuilder
from CodeEntropy.levels.level_dag import LevelDAG
from CodeEntropy.trajectory.frames import FrameSelection
from CodeEntropy.trajectory.source import FrameSource

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Console object supplied by the project's logging configuration; the workflow
# calls ``console.print`` on it for user-facing status messages.
console = LoggingConfig.get_console()

# Alias for the mutable string-keyed dictionary that the workflow builds and
# hands to the level DAG and the entropy graph.
SharedData = dict[str, Any]


class EntropyWorkflow:
    """Coordinate entropy calculations across structural levels.

    This class is responsible for orchestration and IO-level concerns
    (selection, grouping, running graphs, and finalizing results). Domain
    calculations live in dedicated components (LevelDAG, EntropyGraph,
    WaterEntropy, etc.).
    """

    def __init__(
        self,
        run_manager: Any,
        args: Any,
        universe: Any,
        reporter: Any,
        group_molecules: Any,
        dihedral_analysis: Any,
        universe_operations: Any,
    ) -> None:
        """Initialize the entropy workflow manager.

        Args:
            run_manager: Manager for universe IO and unit conversions.
            args: Parsed CLI/user arguments.
            universe: MDAnalysis Universe representing the simulation system.
            reporter: Collector for per-molecule and per-residue outputs.
            group_molecules: Component that groups molecules for averaging.
            dihedral_analysis: Component used to compute conformational
                states. It is only stored here; this class never calls it
                directly.
            universe_operations: Adapter providing common universe operations.
        """
        self._run_manager = run_manager
        self._args = args
        self._universe = universe
        self._reporter = reporter
        self._group_molecules = group_molecules
        self._dihedral_analysis = dihedral_analysis
        self._universe_operations = universe_operations

    def execute(self) -> None:
        """Run the full entropy workflow and emit results.

        This orchestrates the complete entropy pipeline:
            1. Build trajectory frame selection.
            2. Apply atom selection to create the current analysis universe.
            3. Detect hierarchy levels.
            4. Group molecules.
            5. Split groups into water and non-water.
            6. Optionally compute water entropy.
            7. Run level DAG and entropy graph.
            8. Finalize and persist results.
        """
        frame_selection = self._build_frame_selection()
        console.print(
            f"Analyzing a total of {frame_selection.n_frames} "
            f"frames in this calculation."
        )

        reduced_universe = self._build_reduced_universe(frame_selection)
        levels = self._detect_levels(reduced_universe)
        groups = self._group_molecules.grouping_molecules(
            reduced_universe, self._args.grouping
        )
        nonwater_groups, water_groups = self._split_water_groups(
            reduced_universe, groups
        )

        # Water gets its dedicated treatment only when it was requested and
        # the system holds both water and non-water groups; otherwise water
        # groups are folded back in and handled like any other group.
        if self._args.water_entropy and water_groups and nonwater_groups:
            self._compute_water_entropy(frame_selection, water_groups)
        else:
            nonwater_groups.update(water_groups)

        shared_data = self._build_shared_data(
            reduced_universe=reduced_universe,
            levels=levels,
            groups=nonwater_groups,
            frame_selection=frame_selection,
        )
        self._configure_parallel_frame_execution(shared_data)

        try:
            with self._reporter.progress(transient=False) as p:
                self._run_level_dag(shared_data, progress=p)
                self._run_entropy_graph(shared_data, progress=p)
        finally:
            # Always release the Dask client, even when a graph fails.
            client = shared_data.get("dask_client")
            if client is not None:
                client.close()

        self._finalize_molecule_results()
        self._reporter.log_tables()

    def _configure_parallel_frame_execution(self, shared_data: SharedData) -> None:
        """Attach a Dask client to shared_data if parallel frames are requested.

        Parallel execution is enabled when any of the ``parallel_frames``,
        ``use_dask`` or ``hpc`` attributes on the parsed arguments is truthy
        (missing attributes count as false). With ``hpc`` the client comes
        from ``HPCDaskManager``; otherwise a local process-based
        ``dask.distributed.Client`` is created.

        Args:
            shared_data: Shared data dict; receives ``"dask_client"`` and
                ``"parallel_frames"`` entries when parallelism is enabled.

        Raises:
            RuntimeError: If a local client is needed but ``dask.distributed``
                cannot be imported.
        """
        use_parallel = bool(
            getattr(self._args, "parallel_frames", False)
            or getattr(self._args, "use_dask", False)
            or getattr(self._args, "hpc", False)
        )
        if not use_parallel:
            return

        # A client supplied up front is reused rather than replaced.
        if "dask_client" in shared_data:
            shared_data["parallel_frames"] = True
            return

        if getattr(self._args, "hpc", False):
            client = HPCDaskManager(self._args).configure_cluster()
            shared_data["dask_client"] = client
            shared_data["parallel_frames"] = True
            return

        # Imported lazily so that dask is only required when actually used.
        try:
            from dask.distributed import Client
        except ImportError as exc:
            raise RuntimeError(
                "Parallel frame execution was requested, but dask.distributed "
                "is not installed."
            ) from exc

        shared_data["dask_client"] = Client(
            processes=True,
            n_workers=getattr(self._args, "dask_workers", None),
            threads_per_worker=getattr(self._args, "dask_threads_per_worker", 1),
        )
        shared_data["parallel_frames"] = True

    def _build_frame_selection(self) -> FrameSelection:
        """Build the workflow frame selection.

        Returns:
            FrameSelection containing absolute source-trajectory frame indices.

        Notes:
            Physical frame slicing is not used. The selected frame indices are
            the global workflow frame contract and are consumed by FrameSource.
        """
        start, end, step = self._get_trajectory_bounds()
        return FrameSelection.from_bounds(
            start=start,
            stop=end,
            step=step,
        )

    def _build_shared_data(
        self,
        reduced_universe: Any,
        levels: Any,
        groups: Mapping[int, Any],
        frame_selection: FrameSelection,
    ) -> SharedData:
        """Build the shared_data dict used by nodes and graphs.

        Args:
            reduced_universe: Active analysis universe after atom selection.
                The trajectory is not physically frame-sliced.
            levels: Level definition per molecule id.
            groups: Mapping of group id to molecule ids. A shallow copy is
                stored, so later changes to the argument are not reflected.
            frame_selection: Explicit absolute workflow frame selection.

        Returns:
            Shared data dictionary for DAG/graph execution.
        """
        frame_source = FrameSource(
            universe=reduced_universe,
            selection=frame_selection,
        )

        shared_data: SharedData = {
            "entropy_manager": self,
            "run_manager": self._run_manager,
            "reporter": self._reporter,
            "args": self._args,
            "universe": self._universe,
            "reduced_universe": reduced_universe,
            "levels": levels,
            "groups": dict(groups),
            "start": frame_selection.source_start,
            "end": frame_selection.source_stop_exclusive,
            "step": frame_selection.infer_source_step(),
            "n_frames": frame_selection.n_frames,
            "frame_selection": frame_selection,
            "frame_source": frame_source,
            # Both keys hold independent lists built from the same indices.
            "frame_indices": list(frame_selection.indices),
            "source_frame_indices": list(frame_selection.indices),
        }
        return shared_data

    def _run_level_dag(
        self, shared_data: SharedData, *, progress: object | None = None
    ) -> None:
        """Execute the structural/level DAG.

        Args:
            shared_data: Shared data dict that will be mutated by the DAG.
            progress: Optional progress sink provided by
                ResultsReporter.progress().
        """
        LevelDAG(self._universe_operations).build().execute(
            shared_data, progress=progress
        )

    def _run_entropy_graph(
        self, shared_data: SharedData, *, progress: object | None = None
    ) -> None:
        """Execute the entropy calculation graph and merge its results.

        Args:
            shared_data: Shared data dict handed to the graph; the mapping
                returned by the graph is merged into it afterwards.
            progress: Optional progress sink provided by
                ResultsReporter.progress().
        """
        entropy_results = EntropyGraph().build().execute(shared_data, progress=progress)
        shared_data.update(entropy_results)

    def _get_trajectory_bounds(self) -> tuple[int, int, int]:
        """Return normalized start, end, and step frame indices from args.

        Missing values are replaced by defaults: ``start`` falls back to 0,
        ``step`` to 1, and ``end`` to the trajectory length when it is either
        ``None`` or the sentinel ``-1``. All values are coerced with ``int``.

        No range checking is done here. NOTE(review): presumably
        ``FrameSelection.from_bounds`` rejects invalid windows - confirm.

        Returns:
            Tuple of ``(start, end, step)``.
        """
        n_total = len(self._universe.trajectory)

        start = 0 if self._args.start is None else int(self._args.start)
        end = (
            n_total
            if self._args.end is None or int(self._args.end) == -1
            else int(self._args.end)
        )
        step = 1 if self._args.step is None else int(self._args.step)

        return start, end, step

    def _build_reduced_universe(self, frame_selection: FrameSelection) -> Any:
        """Apply atom selection and return the active analysis universe.

        When the selection string is ``"all"`` the original universe is
        returned untouched. Otherwise the selection is applied through the
        universe-operations adapter and the result is also written out via
        the run manager.

        Args:
            frame_selection: Workflow frame selection. Used for validation.

        Returns:
            MDAnalysis Universe after atom selection. Frames are not
            physically sliced; selected-frame access is handled by
            FrameSource.

        Raises:
            ValueError: If no frames are selected.
        """
        if frame_selection.n_frames == 0:
            raise ValueError("Frame selection is empty.")

        selection = self._args.selection_string
        if selection == "all":
            return self._universe

        reduced_atoms = self._universe_operations.select_atoms(
            self._universe,
            selection,
        )
        name = f"{len(reduced_atoms.trajectory)}_frame_dump_atom_selection"
        self._run_manager.write_universe(reduced_atoms, name)
        return reduced_atoms

    def _detect_levels(self, reduced_universe: Any) -> Any:
        """Detect hierarchy levels for each molecule in the reduced universe.

        Args:
            reduced_universe: Reduced MDAnalysis Universe.

        Returns:
            Levels structure as returned by `HierarchyBuilder.select_levels`
            (the molecule count returned alongside it is discarded).
        """
        level_hierarchy = HierarchyBuilder()
        _number_molecules, levels = level_hierarchy.select_levels(reduced_universe)
        return levels

    def _split_water_groups(
        self,
        universe: Any,
        groups: Mapping[int, Any],
    ) -> tuple[dict[int, Any], dict[int, Any]]:
        """Partition molecule groups into water and non-water groups.

        A group counts as water when at least one residue of at least one of
        its molecules has a resid that also appears in the ``"water"`` atom
        selection of ``universe``.

        Args:
            universe: The MDAnalysis Universe used to build the molecule
                groups (typically the reduced universe).
            groups: Mapping of group_id -> list of molecule fragment indices.

        Returns:
            A tuple ``(nonwater_groups, water_groups)``, each a dict of
            group_id -> molecule ids with keys inserted in ascending order.
        """
        water_atoms = universe.select_atoms("water")
        water_resids = {res.resid for res in water_atoms.residues}

        # Resolve the fragment tuple once: the attribute may be recomputed on
        # every access, which made the per-molecule lookup quadratic. It is
        # skipped entirely when there is nothing to classify.
        fragments = universe.atoms.fragments if groups else ()

        nonwater_groups: dict[int, Any] = {}
        water_groups: dict[int, Any] = {}
        for gid, mol_ids in sorted(groups.items()):
            contains_water = any(
                res.resid in water_resids
                for mol_id in mol_ids
                for res in fragments[mol_id].residues
            )
            target = water_groups if contains_water else nonwater_groups
            target[gid] = mol_ids

        return nonwater_groups, water_groups

    def _compute_water_entropy(
        self,
        frame_selection: FrameSelection,
        water_groups: Mapping[int, Any],
    ) -> None:
        """Compute water entropy for each water group and adjust selection string.

        Does nothing when there are no water groups, when water entropy was
        not requested, or when the frame selection has no source start/stop.
        After the calculation, ``args.selection_string`` is rewritten to
        exclude water.

        Args:
            frame_selection: Workflow frame selection.
            water_groups: Mapping of group id to molecule ids for waters.
        """
        if not water_groups or not self._args.water_entropy:
            return

        start = frame_selection.source_start
        end = frame_selection.source_stop_exclusive
        step = frame_selection.infer_source_step()
        if start is None or end is None:
            return

        water_entropy = WaterEntropy(self._args, self._reporter)
        for group_id in water_groups:
            water_entropy.calculate_and_log(
                universe=self._universe,
                start=start,
                end=end,
                step=step,
                group_id=group_id,
            )

        self._args.selection_string = (
            f"{self._args.selection_string} and not water"
            if self._args.selection_string != "all"
            else "not water"
        )

        logger.debug("WaterEntropy: molecule_data= %s", self._reporter.molecule_data)
        logger.debug("WaterEntropy: residue_data= %s", self._reporter.residue_data)

    def _finalize_molecule_results(self) -> None:
        """Aggregate group totals and persist results to JSON.

        Sums the numeric results per group (ignoring existing "Group Total"
        rows and logging a warning for values that cannot be converted to
        float), appends one "Group Total" row per group to the reporter's
        molecule data, then builds the molecule and residue tables and hands
        them to ``reporter.save_dataframes_as_json`` together with the
        configured output file.
        """
        entropy_by_group = defaultdict(float)
        for group_id, level, _etype, result in self._reporter.molecule_data:
            if level == "Group Total":
                continue
            try:
                entropy_by_group[group_id] += float(result)
            except (TypeError, ValueError):
                logger.warning("Skipping invalid entry: %s, %s", group_id, result)

        for group_id, total in sorted(entropy_by_group.items()):
            self._reporter.molecule_data.append(
                (group_id, "Group Total", "Group Total Entropy", total)
            )

        molecule_df = pd.DataFrame(
            self._reporter.molecule_data,
            columns=["Group ID", "Level", "Type", "Result (J/mol/K)"],
        )
        residue_df = pd.DataFrame(
            self._reporter.residue_data,
            columns=[
                "Group ID",
                "Residue Name",
                "Level",
                "Type",
                "Frame Count",
                "Result (J/mol/K)",
            ],
        )
        self._reporter.save_dataframes_as_json(
            molecule_df,
            residue_df,
            self._args.output_file,
            args=self._args,
            include_raw_tables=False,
        )