Source code for CodeEntropy.levels.nodes.axes_topology

"""Build static axes-topology metadata for frame covariance calculations.

This module caches topology-only atom-index relationships needed by customised
axes calculations. The cache avoids repeated MDAnalysis selection parsing inside
the frame-local covariance loop while preserving frame-dependent positions,
forces, centres, axes, torques, and moments of inertia.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any

import numpy as np

logger = logging.getLogger(__name__)

# Cache key for united-atom axes topology: (mol_id, local_residue_id, ua_id).
UAKey = tuple[int, int, int]
# Cache key for residue axes topology: (mol_id, local_residue_id).
ResidueKey = tuple[int, int]


@dataclass(frozen=True, eq=False)
class UAAxesTopology:
    """Static topology required to compute customised united-atom axes.

    Instances compare and hash by identity (``eq=False``). All array fields
    are NumPy arrays, for which a dataclass-generated field-wise ``__eq__``
    raises (ambiguous array truth value) and a generated ``__hash__`` raises
    (``ndarray`` is unhashable). ``frozen`` prevents rebinding fields but does
    not make the arrays themselves read-only.

    Attributes:
        heavy_atom_index: Reduced-universe atom index for the UA heavy atom.
        ua_atom_indices: Atom indices for the UA heavy atom and its bonded
            hydrogens/light atoms.
        ua_all_atom_indices: Atom indices for the UA heavy atom, bonded heavy
            atoms, and bonded hydrogens/light atoms.
        bonded_heavy_indices: Heavy atoms bonded to the UA heavy atom.
        bonded_light_indices: Hydrogens/light atoms bonded to the UA heavy
            atom.
        residue_heavy_indices: Heavy atoms in the parent residue.
        residue_ua_masses: UA masses for heavy atoms in the parent residue.
    """

    heavy_atom_index: int
    ua_atom_indices: np.ndarray
    ua_all_atom_indices: np.ndarray
    bonded_heavy_indices: np.ndarray
    bonded_light_indices: np.ndarray
    residue_heavy_indices: np.ndarray
    residue_ua_masses: np.ndarray
@dataclass(frozen=True, eq=False)
class ResidueAxesTopology:
    """Static topology required to compute customised residue axes.

    Instances compare and hash by identity (``eq=False``). The array fields
    are NumPy arrays, for which a dataclass-generated field-wise ``__eq__``
    raises (ambiguous array truth value) and a generated ``__hash__`` raises
    (``ndarray`` is unhashable). ``frozen`` prevents rebinding fields but does
    not make the arrays themselves read-only.

    Attributes:
        residue_heavy_indices: Heavy atom indices in the residue.
        residue_ua_masses: UA masses for heavy atoms in the residue.
        has_neighbor_bonds: Whether the residue is bonded to a neighbouring
            residue according to the original customised residue-axis
            selection.
    """

    residue_heavy_indices: np.ndarray
    residue_ua_masses: np.ndarray
    has_neighbor_bonds: bool
@dataclass(frozen=True)
class AxesTopology:
    """Container holding every cached customised-axes topology entry.

    Both mappings default to a fresh, empty ``dict`` per instance, so a bare
    ``AxesTopology()`` means "nothing cached". The workflow node writes such
    an empty instance when customised axes are disabled.

    Attributes:
        ua: United-atom entries keyed by
            ``(mol_id, local_residue_id, ua_id)``.
        residue: Residue entries keyed by ``(mol_id, local_residue_id)``.
    """

    # (mol_id, local_residue_id, ua_id) -> united-atom axes topology.
    ua: dict[UAKey, UAAxesTopology] = field(default_factory=dict)
    # (mol_id, local_residue_id) -> residue axes topology.
    residue: dict[ResidueKey, ResidueAxesTopology] = field(default_factory=dict)
[docs] class BuildAxesTopologyNode: """Build static customised-axes topology before frame covariance execution."""
[docs] def run(self, shared_data: dict[str, Any]) -> dict[str, Any]: """Build cached axes topology and write it into shared data. The cache is only populated when ``args.customised_axes`` is true. When customised axes are disabled, an empty cache is still written so later stages can read ``shared_data["axes_topology"]`` safely. Args: shared_data: Shared workflow data containing ``args`` and, when customised axes are enabled, ``reduced_universe``, ``levels``, and ``beads``. Returns: Dict containing the cached ``axes_topology`` object. """ args = shared_data["args"] topology = AxesTopology() if not bool(getattr(args, "customised_axes", False)): shared_data["axes_topology"] = topology return {"axes_topology": topology} u = shared_data["reduced_universe"] levels = shared_data["levels"] beads = shared_data["beads"] ua_topology: dict[UAKey, UAAxesTopology] = {} residue_topology: dict[ResidueKey, ResidueAxesTopology] = {} fragments = u.atoms.fragments for mol_id, level_list in enumerate(levels): mol = fragments[mol_id] if "residue" in level_list: self._add_residue_topology( mol=mol, mol_id=mol_id, beads=beads, out=residue_topology, ) if "united_atom" in level_list: self._add_ua_topology( u=u, mol=mol, mol_id=mol_id, beads=beads, out=ua_topology, ) topology = AxesTopology(ua=ua_topology, residue=residue_topology) shared_data["axes_topology"] = topology return {"axes_topology": topology}
def _add_residue_topology( self, *, mol: Any, mol_id: int, beads: dict[Any, list[np.ndarray]], out: dict[ResidueKey, ResidueAxesTopology], ) -> None: """Cache static residue axes topology for one molecule. Args: mol: Molecule AtomGroup. mol_id: Molecule index. beads: Bead-index mapping produced by ``BuildBeadsNode``. out: Output residue topology mapping mutated in place. """ bead_key = (mol_id, "residue") bead_idx_list = beads.get(bead_key, []) if not bead_idx_list: return for local_res_i, residue in enumerate(mol.residues): if local_res_i >= len(bead_idx_list): continue residue_atoms = residue.atoms residue_heavy = residue_atoms.select_atoms("mass 2 to 999") residue_heavy_indices = residue_heavy.indices.astype(int, copy=True) residue_ua_masses = np.asarray( self._get_ua_masses_from_topology(residue_atoms), dtype=float, ) has_neighbor_bonds = self._has_neighbor_bonds( mol=mol, local_res_i=local_res_i, ) out[(mol_id, local_res_i)] = ResidueAxesTopology( residue_heavy_indices=residue_heavy_indices, residue_ua_masses=residue_ua_masses, has_neighbor_bonds=has_neighbor_bonds, ) def _add_ua_topology( self, *, u: Any, mol: Any, mol_id: int, beads: dict[Any, list[np.ndarray]], out: dict[UAKey, UAAxesTopology], ) -> None: """Cache static UA axes topology for one molecule. Args: u: Reduced universe used to resolve bead atom-index arrays. mol: Molecule AtomGroup. mol_id: Molecule index. beads: Bead-index mapping produced by ``BuildBeadsNode``. out: Output UA topology mapping mutated in place. 
""" for local_res_i, residue in enumerate(mol.residues): bead_key = (mol_id, "united_atom", local_res_i) bead_idx_list = beads.get(bead_key, []) if not bead_idx_list: continue residue_atoms = residue.atoms residue_heavy = residue_atoms.select_atoms("prop mass > 1.1") residue_heavy_indices = residue_heavy.indices.astype(int, copy=True) residue_ua_masses = np.asarray( self._get_ua_masses_from_topology(residue_atoms), dtype=float, ) for ua_i, bead_indices in enumerate(bead_idx_list): bead = u.atoms[bead_indices] heavy = bead.select_atoms("prop mass > 1.1") if len(heavy) == 0: logger.warning( "Skipping UA axes topology with no heavy atom: " "mol=%s residue=%s ua=%s", mol_id, local_res_i, ua_i, ) continue heavy_atom = heavy[0] bonded_heavy, bonded_light = self._split_bonded_atoms(heavy_atom) heavy_index = np.asarray([int(heavy_atom.index)], dtype=int) bonded_heavy_indices = bonded_heavy.indices.astype(int, copy=True) bonded_light_indices = bonded_light.indices.astype(int, copy=True) ua_atom_indices = np.concatenate( [heavy_index, bonded_light_indices], axis=0, ) ua_all_atom_indices = np.concatenate( [heavy_index, bonded_heavy_indices, bonded_light_indices], axis=0, ) out[(mol_id, local_res_i, ua_i)] = UAAxesTopology( heavy_atom_index=int(heavy_atom.index), ua_atom_indices=ua_atom_indices, ua_all_atom_indices=ua_all_atom_indices, bonded_heavy_indices=bonded_heavy_indices, bonded_light_indices=bonded_light_indices, residue_heavy_indices=residue_heavy_indices, residue_ua_masses=residue_ua_masses, ) @staticmethod def _has_neighbor_bonds(*, mol: Any, local_res_i: int) -> bool: """Return whether a residue is bonded to neighbouring residues. Args: mol: Molecule AtomGroup used for the original bonded-neighbour selection. local_res_i: Residue index local to ``mol``. Returns: True when the residue has bonded atoms in the previous or next residue according to the original customised residue-axis query. 
""" index_prev = local_res_i - 1 index_next = local_res_i + 1 atom_set = mol.select_atoms( f"(resindex {index_prev} or resindex {index_next}) " f"and bonded resid {local_res_i}" ) return len(atom_set) > 0 @staticmethod def _split_bonded_atoms(atom: Any) -> tuple[Any, Any]: """Return bonded heavy and light atoms for one atom. Args: atom: MDAnalysis Atom. Returns: Tuple containing bonded heavy atoms and bonded hydrogens/light atoms. """ bonded_atoms = atom.bonded_atoms bonded_heavy = bonded_atoms.select_atoms("mass 2 to 999") bonded_light = bonded_atoms.select_atoms("mass 1 to 1.1") return bonded_heavy, bonded_light @staticmethod def _get_ua_masses_from_topology(atom_group: Any) -> list[float]: """Return UA masses using static bonded atom relationships. Args: atom_group: AtomGroup containing atoms from one residue. Returns: List of UA masses, one for each heavy atom in ``atom_group``. """ ua_masses: list[float] = [] for atom in atom_group: if atom.mass <= 1.1: continue ua_mass = float(atom.mass) bonded_atoms = getattr(atom, "bonded_atoms", None) if bonded_atoms is None: ua_masses.append(ua_mass) continue bonded_h_atoms = bonded_atoms.select_atoms("mass 1 to 1.1") for hydrogen in bonded_h_atoms: ua_mass += float(hydrogen.mass) ua_masses.append(ua_mass) return ua_masses