# Source code for CodeEntropy.levels.nodes.beads

"""Build bead (AtomGroup index) definitions for each molecule and hierarchy level.

This module defines the `BuildBeadsNode`, a static DAG node that constructs bead
definitions once, in reduced-universe index space. These bead definitions are
used by later frame-level nodes (e.g., covariance construction) without needing
to re-run selection logic every frame.

Beads are stored as arrays of atom indices (in the reduced universe).
"""

from __future__ import annotations

import logging
from collections import defaultdict
from collections.abc import MutableMapping
from dataclasses import dataclass
from typing import Any

import numpy as np

from CodeEntropy.levels.hierarchy import HierarchyBuilder

logger = logging.getLogger(__name__)

BeadKey = tuple[int, str] | tuple[int, str, int]
BeadsMap = dict[BeadKey, list[np.ndarray]]


@dataclass(frozen=True, eq=False)
class UnitedAtomBead:
    """A united-atom bead associated with a residue bucket.

    Equality and hashing are value-based and implemented explicitly. The
    dataclass-generated versions compare/hash the raw field tuple, which
    raises ``ValueError`` (ambiguous array truth value) or ``TypeError``
    (unhashable ``ndarray``) as soon as ``atom_indices`` holds an array.

    Attributes:
        residue_id: Local residue index within the molecule
            (0..n_residues-1).
        atom_indices: Atom indices (reduced-universe index space) belonging
            to the bead.
    """

    residue_id: int
    atom_indices: np.ndarray

    def __eq__(self, other: object) -> bool:
        """Return True when both beads share a residue id and equal indices."""
        # Mirror the dataclass convention: only instances of the exact same
        # class are comparable.
        if other.__class__ is not self.__class__:
            return NotImplemented
        return self.residue_id == other.residue_id and bool(
            np.array_equal(self.atom_indices, other.atom_indices)
        )

    def __hash__(self) -> int:
        """Return a hash consistent with `__eq__`."""
        # Hash the values (not the raw bytes) so arrays that compare equal
        # but differ in dtype still hash identically.
        flat = np.asarray(self.atom_indices).ravel().tolist()
        return hash((self.residue_id, tuple(flat)))
class BuildBeadsNode:
    """Build bead definitions once, in reduced-universe index space.

    Output contract:
        Writes `shared_data["beads"]` with keys:
            - (mol_id, "united_atom", res_id) -> list[np.ndarray]
            - (mol_id, "residue") -> list[np.ndarray]
            - (mol_id, "polymer") -> list[np.ndarray]

    Notes:
        United-atom beads are generated at the molecule level (preserving the
        underlying ordering provided by `HierarchyBuilder.get_beads`) and then
        grouped into residue buckets based on the heavy atom that defines the
        bead.
    """

    def __init__(self, hierarchy: HierarchyBuilder | None = None) -> None:
        """Initialize the node.

        Args:
            hierarchy: Optional `HierarchyBuilder` dependency. If not
                provided, a default instance is created.
        """
        # Test against None explicitly so an injected dependency that happens
        # to be falsy is not silently replaced by a default instance.
        self._hier = hierarchy if hierarchy is not None else HierarchyBuilder()

    def run(self, shared_data: dict[str, Any]) -> dict[str, Any]:
        """Build bead definitions for all molecules and levels.

        Args:
            shared_data: Shared data dictionary. Requires:
                - "reduced_universe": MDAnalysis.Universe
                - "levels": list[list[str]]

        Returns:
            Dict containing the "beads" mapping (also written into
            shared_data).

        Raises:
            KeyError: If required keys are missing from `shared_data`.
        """
        u = shared_data["reduced_universe"]
        levels: list[list[str]] = shared_data["levels"]

        beads: BeadsMap = {}
        fragments = u.atoms.fragments

        # Entry i of `levels` is applied to fragment i.
        # NOTE(review): assumes `levels` has no more entries than there are
        # fragments; otherwise the lookup below raises IndexError.
        for mol_id, level_list in enumerate(levels):
            mol = fragments[mol_id]

            if "united_atom" in level_list:
                self._add_united_atom_beads(beads=beads, mol_id=mol_id, mol=mol)

            if "residue" in level_list:
                self._add_residue_beads(beads=beads, mol_id=mol_id, mol=mol)

            if "polymer" in level_list:
                self._add_polymer_beads(beads=beads, mol_id=mol_id, mol=mol)

        shared_data["beads"] = beads
        return {"beads": beads}

    def _collect_valid_beads(
        self, mol, mol_id: int, level: str
    ) -> list[np.ndarray]:
        """Return index arrays for every non-empty bead of `mol` at `level`.

        Args:
            mol: MDAnalysis AtomGroup representing the molecule.
            mol_id: Molecule (fragment) index, used for logging context.
            level: Hierarchy level name passed to the hierarchy builder.

        Returns:
            Index arrays of the kept beads, in the order the hierarchy
            builder produced them. Empty beads are skipped.
        """
        kept: list[np.ndarray] = []
        for bead_i, bead in enumerate(self._hier.get_beads(mol, level)):
            atom_indices = self._validate_bead_indices(
                bead, mol_id=mol_id, level=level, bead_i=bead_i
            )
            if atom_indices is not None:
                kept.append(atom_indices)
        return kept

    def _add_united_atom_beads(
        self, beads: MutableMapping[BeadKey, list[np.ndarray]], mol_id: int, mol
    ) -> None:
        """Compute and store united-atom beads grouped into residue buckets.

        Args:
            beads: Output bead mapping mutated in-place.
            mol_id: Molecule (fragment) index.
            mol: MDAnalysis AtomGroup representing the molecule.
        """
        ua_beads = self._hier.get_beads(mol, "united_atom")

        # Build the resindex -> local index lookup once per molecule instead
        # of rescanning `mol.residues` for every bead.
        resindex_to_local = self._resindex_to_local_map(mol)

        buckets: defaultdict[int, list[np.ndarray]] = defaultdict(list)
        for bead_i, bead in enumerate(ua_beads):
            atom_indices = self._validate_bead_indices(
                bead, mol_id=mol_id, level="united_atom", bead_i=bead_i
            )
            if atom_indices is None:
                continue

            residue_id = self._infer_local_residue_id(
                mol=mol, bead=bead, resindex_to_local=resindex_to_local
            )
            buckets[residue_id].append(atom_indices)

        # Every residue gets a key, even when no bead landed in its bucket,
        # so consumers can index by local residue id unconditionally.
        for local_res_id in range(len(mol.residues)):
            beads[(mol_id, "united_atom", local_res_id)] = buckets.get(
                local_res_id, []
            )

    def _add_residue_beads(
        self, beads: MutableMapping[BeadKey, list[np.ndarray]], mol_id: int, mol
    ) -> None:
        """Compute and store residue beads.

        Args:
            beads: Output bead mapping mutated in-place.
            mol_id: Molecule (fragment) index.
            mol: MDAnalysis AtomGroup representing the molecule.
        """
        kept = self._collect_valid_beads(mol, mol_id=mol_id, level="residue")
        beads[(mol_id, "residue")] = kept

        if not kept:
            logger.error(
                "[BuildBeadsNode] No residue beads kept for mol=%s. Residue-level "
                "entropy may be 0.0.",
                mol_id,
            )

    def _add_polymer_beads(
        self, beads: MutableMapping[BeadKey, list[np.ndarray]], mol_id: int, mol
    ) -> None:
        """Compute and store polymer beads.

        Args:
            beads: Output bead mapping mutated in-place.
            mol_id: Molecule (fragment) index.
            mol: MDAnalysis AtomGroup representing the molecule.
        """
        beads[(mol_id, "polymer")] = self._collect_valid_beads(
            mol, mol_id=mol_id, level="polymer"
        )

    @staticmethod
    def _validate_bead_indices(
        bead, mol_id: int, level: str, bead_i: int
    ) -> np.ndarray | None:
        """Return a bead's atom indices, or None if the bead is empty.

        Args:
            bead: MDAnalysis AtomGroup representing the bead.
            mol_id: Molecule id used only for logging context.
            level: Level name used only for logging context.
            bead_i: Bead index used only for logging context.

        Returns:
            A copy of the bead indices as a NumPy array, or None if the bead
            is empty.
        """
        if len(bead) == 0:
            logger.warning(
                "[BuildBeadsNode] Empty bead skipped: mol=%s level=%s bead_i=%s",
                mol_id,
                level,
                bead_i,
            )
            return None

        return bead.indices.copy()

    @staticmethod
    def _resindex_to_local_map(mol) -> dict[int, int]:
        """Map each residue's `resindex` to its position in `mol.residues`.

        Args:
            mol: Molecule AtomGroup.

        Returns:
            Dict from `resindex` to local residue index. If a `resindex`
            occurs more than once, the first position wins, matching a
            front-to-back scan of `mol.residues`.
        """
        mapping: dict[int, int] = {}
        for local_i, res in enumerate(mol.residues):
            mapping.setdefault(int(res.resindex), local_i)
        return mapping

    @staticmethod
    def _infer_local_residue_id(
        mol, bead, resindex_to_local: dict[int, int] | None = None
    ) -> int:
        """Infer the local residue bucket for a united-atom bead.

        Strategy:
            - Select heavy atoms in the bead (mass > 1.1).
            - Use the first heavy atom's `resindex` (universe-level).
            - Map that universe-level `resindex` back to the molecule's local
              residue index.

        Args:
            mol: Molecule AtomGroup.
            bead: United-atom bead AtomGroup.
            resindex_to_local: Optional precomputed mapping from `resindex`
                to local residue index for `mol`. When omitted, it is derived
                from `mol.residues` on demand.

        Returns:
            Local residue index in [0, len(mol.residues) - 1]. Falls back to
            0 if the bead has no heavy atom or the mapping cannot be
            determined.
        """
        heavy = bead.select_atoms("prop mass > 1.1")
        if len(heavy) == 0:
            return 0

        target_resindex = int(heavy[0].resindex)
        if resindex_to_local is None:
            resindex_to_local = BuildBeadsNode._resindex_to_local_map(mol)

        local_id = resindex_to_local.get(target_resindex)
        if local_id is None:
            # Keep the best-effort fallback, but make the mis-bucketing
            # visible instead of silent.
            logger.warning(
                "[BuildBeadsNode] Heavy-atom resindex %s not found in molecule "
                "residues; falling back to residue bucket 0.",
                target_resindex,
            )
            return 0
        return local_id