# Source code for CodeEntropy.levels.dihedrals

"""Dihedral state assignment for conformational entropy.

This module converts dihedral angle time series into discrete conformational
state labels. The resulting state labels are used downstream to compute
conformational entropy.
"""

from __future__ import annotations

import logging
from typing import Any

import numpy as np
from MDAnalysis.analysis.dihedrals import Dihedral
from rich.progress import TaskID

from CodeEntropy.results.reporter import _RichProgressSink

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Key type of the united-atom result dictionaries. `_assign_states` builds
# these keys as (group_id, res_id).
UAKey = tuple[int, int]


class ConformationStateBuilder:
    """Build conformational state labels from dihedral angles.

    The work is done in two passes per molecule group:

    1. ``_identify_peaks`` pools the dihedral angles of every molecule in the
       group, histograms them and records the histogram peaks.
    2. ``_assign_states`` assigns every frame of every molecule to its nearest
       peak and concatenates the per-dihedral peak indices into a state label.
    """

    def __init__(self, universe_operations: Any) -> None:
        """Initialize the analysis helper.

        Args:
            universe_operations: Object providing helper methods:
                - extract_fragment(data_container, molecule_id)
                - select_atoms(atomgroup, selection_string)
        """
        self._universe_operations = universe_operations

    def build_conformational_states(
        self,
        data_container: Any,
        levels: dict[Any, list[str]],
        groups: dict[int, list[Any]],
        bin_width: float,
        progress: _RichProgressSink | None = None,
    ) -> tuple[dict[UAKey, list[str]], list[list[str]], dict[UAKey, int], list[int]]:
        """Build conformational state labels from trajectory dihedrals.

        United-atom (UA) and residue-level states are generated depending on
        which level names are enabled for the first molecule of each group.

        Progress reporting is optional. If a progress sink is provided, a
        single task is created and advanced once per molecule group (or once
        in total when ``groups`` is empty).

        Args:
            data_container: MDAnalysis Universe (or compatible container) used
                to extract fragments and compute dihedral time series.
            levels: Mapping of molecule_id -> enabled level names
                (e.g., ["united_atom", "residue"]). Only the entry of the
                first molecule of each group is consulted.
            groups: Mapping of group_id -> list of molecule_ids.
            bin_width: Histogram bin width in degrees used when identifying
                peak dihedral populations.
            progress: Optional progress sink. Must expose add_task(),
                update(), and advance().

        Returns:
            tuple: (states_ua, states_res, flexible_ua, flexible_res)

            - states_ua: Dict mapping (group_id, res_id) -> list of state
              labels (strings), one per frame per molecule.
            - states_res: List of residue-level state-label lists. It starts
              with one empty list per group and gains one further entry per
              non-empty group (see the NOTE in the body).
            - flexible_ua: Dict mapping (group_id, res_id) -> number of
              dihedrals that visit more than one state.
            - flexible_res: One flexible-dihedral count per non-empty group,
              in iteration order of ``groups``.
        """
        states_ua: dict[UAKey, list[str]] = {}
        # NOTE(review): states_res is pre-filled with one empty list per group
        # and `_assign_states` then *appends* one more list per processed
        # group, so the populated entries live at indices >= len(groups).
        # This looks unintended, but callers may rely on the current layout;
        # it is preserved here and should be confirmed against the consumer.
        states_res: list[list[str]] = [[] for _ in range(len(groups))]
        flexible_ua: dict[UAKey, int] = {}
        flexible_res: list[int] = []

        task: TaskID | None = None
        if progress is not None:
            task = progress.add_task(
                "[green]Conformational states",
                total=max(1, len(groups)),
                title="Initializing",
            )

        def _report(title: str | None = None, *, advance: bool = False) -> None:
            """Forward a title change and/or one step to the progress sink."""
            if progress is None or task is None:
                return
            if title is not None:
                progress.update(task, title=title)
            if advance:
                progress.advance(task)

        if not groups:
            _report("No groups", advance=True)
            # Return the same 4-tuple shape as the normal path so callers can
            # always unpack four values.
            return states_ua, states_res, flexible_ua, flexible_res

        for group_id, molecules in groups.items():
            if not molecules:
                _report(f"Group {group_id} (empty)", advance=True)
                continue

            _report(f"Group {group_id}")

            # The first molecule is taken as representative of the group.
            level_list = levels[molecules[0]]

            peaks_ua, peaks_res = self._identify_peaks(
                data_container=data_container,
                molecules=molecules,
                bin_width=bin_width,
                level_list=level_list,
            )

            self._assign_states(
                data_container=data_container,
                group_id=group_id,
                molecules=molecules,
                level_list=level_list,
                peaks_ua=peaks_ua,
                peaks_res=peaks_res,
                states_ua=states_ua,
                states_res=states_res,
                flexible_ua=flexible_ua,
                flexible_res=flexible_res,
            )

            _report(advance=True)

        # Lazy %-formatting: these containers can be very large and should
        # only be rendered when debug logging is actually enabled.
        logger.debug("States UA: %s", states_ua)
        logger.debug("Number of flexible dihedrals UA: %s", flexible_ua)
        logger.debug("States Res: %s", states_res)
        logger.debug("Number of flexible dihedrals Res: %s", flexible_res)

        return states_ua, states_res, flexible_ua, flexible_res

    def _select_heavy_residue(self, mol: Any, res_id: int) -> Any:
        """Select heavy atoms in a residue by residue index.

        Args:
            mol: Molecule AtomGroup.
            res_id: Index into ``mol.residues``.

        Returns:
            AtomGroup of the atoms with mass > 1.1 inside the index range
            spanned by the residue's first and last atom.
        """
        residue_indices = mol.residues[res_id].atoms.indices
        selection1 = residue_indices[0]
        selection2 = residue_indices[-1]

        res_container = self._universe_operations.select_atoms(
            mol, f"index {selection1}:{selection2}"
        )
        return self._universe_operations.select_atoms(res_container, "prop mass > 1.1")

    def _get_dihedrals(self, data_container: Any, level: str) -> list[Any]:
        """Return dihedral AtomGroups for a container at a given level.

        Args:
            data_container: MDAnalysis container (AtomGroup/Universe).
            level: Either "united_atom" or "residue". Any other value yields
                an empty list.

        Returns:
            List of AtomGroups (each representing a dihedral definition).
        """
        atom_groups: list[Any] = []

        if level == "united_atom":
            atom_groups = [dihedral.atoms for dihedral in data_container.dihedrals]
        elif level == "residue":
            num_residues = len(data_container.residues)
            # One dihedral per window of four consecutive residues; the range
            # is empty when there are fewer than four residues.
            for window_end in range(4, num_residues + 1):
                first = window_end - 4
                second = window_end - 3
                third = window_end - 2
                fourth = window_end - 1
                atom1 = data_container.select_atoms(
                    f"resindex {first} and bonded resindex {second}"
                )
                atom2 = data_container.select_atoms(
                    f"resindex {second} and bonded resindex {first}"
                )
                atom3 = data_container.select_atoms(
                    f"resindex {third} and bonded resindex {fourth}"
                )
                atom4 = data_container.select_atoms(
                    f"resindex {fourth} and bonded resindex {third}"
                )
                atom_groups.append(atom1 + atom2 + atom3 + atom4)

        logger.debug("Level: %s, Dihedrals: %s", level, atom_groups)
        return atom_groups

    def _identify_peaks(
        self,
        data_container: Any,
        molecules: list[Any],
        bin_width: float,
        level_list: list[Any],
    ) -> tuple[list[list[Any]], list[Any]]:
        """Identify histogram peaks ("convex turning points") for each dihedral.

        Important:
            This function intentionally preserves the legacy behavior: it
            samples over the full trajectory length for each molecule and does
            not apply start/end/step to the Dihedral run.

        Args:
            data_container: MDAnalysis universe.
            molecules: Molecule ids in the group (must be non-empty; the first
                one supplies the frame and residue counts).
            bin_width: Histogram bin width (degrees).
            level_list: Enabled level names for this group.

        Returns:
            tuple: (peaks_ua, peaks_res)

            - peaks_ua: ``peaks_ua[res_id][dihedral_index]`` -> list of peaks.
            - peaks_res: ``peaks_res[dihedral_index]`` -> list of peaks.
        """
        rep_mol = self._universe_operations.extract_fragment(
            data_container, molecules[0]
        )
        number_frames = len(rep_mol.trajectory)
        num_residues = len(rep_mol.residues)

        num_dihedrals_ua: list[int] = [0] * num_residues
        num_dihedrals_res = 0
        # Angle accumulators are always dicts (dihedral_index -> angles) so a
        # later molecule can extend them without a type mismatch.
        phi_ua: dict[int, dict[int, list[float]]] = {}
        phi_res: dict[int, list[float]] = {}
        peaks_ua: list[list[Any]] = [[] for _ in range(num_residues)]
        peaks_res: list[Any] = []

        # Pass 1: pool the angles of every molecule in the group.
        for molecule in molecules:
            mol = self._universe_operations.extract_fragment(data_container, molecule)
            for level in level_list:
                if level == "united_atom":
                    for res_id in range(num_residues):
                        heavy_res = self._select_heavy_residue(mol, res_id)
                        dihedrals = self._get_dihedrals(heavy_res, level)
                        num_dihedrals_ua[res_id] = len(dihedrals)
                        if not dihedrals:
                            # No dihedrals, no peaks
                            phi_ua[res_id] = {}
                            continue
                        dihedral_results = Dihedral(dihedrals).run()
                        phi_ua[res_id] = self._process_dihedral_phi(
                            dihedral_results,
                            num_dihedrals_ua[res_id],
                            number_frames,
                            phi_ua.get(res_id, {}),
                        )
                elif level == "residue":
                    dihedrals = self._get_dihedrals(mol, level)
                    num_dihedrals_res = len(dihedrals)
                    if not dihedrals:
                        # No dihedrals, no peaks
                        phi_res = {}
                        continue
                    dihedral_results = Dihedral(dihedrals).run()
                    phi_res = self._process_dihedral_phi(
                        dihedral_results,
                        num_dihedrals_res,
                        number_frames,
                        phi_res,
                    )

        logger.debug("phi_ua %s", phi_ua)
        logger.debug("phi_res %s", phi_res)

        # Pass 2: histogram the pooled angles and pick the peaks.
        if "united_atom" in level_list:
            for res_id in range(num_residues):
                peaks_ua[res_id] = self._process_histogram(
                    num_dihedrals_ua[res_id], phi_ua[res_id], bin_width
                )
        if "residue" in level_list:
            peaks_res = self._process_histogram(num_dihedrals_res, phi_res, bin_width)

        return peaks_ua, peaks_res

    @staticmethod
    def _wrapped_angles(
        dihedral_results: Any, num_dihedrals: int, number_frames: int
    ) -> np.ndarray:
        """Return the angles as a (number_frames, num_dihedrals) array in [0, 360).

        Negative angles are shifted by +360 degrees.

        Raises:
            IndexError: If ``dihedral_results.results.angles`` holds fewer
                frames or dihedrals than requested.
        """
        if number_frames == 0 or num_dihedrals == 0:
            return np.empty((number_frames, num_dihedrals), dtype=float)

        angles = np.asarray(dihedral_results.results.angles, dtype=float)
        if (
            angles.ndim != 2
            or angles.shape[0] < number_frames
            or angles.shape[1] < num_dihedrals
        ):
            raise IndexError(
                f"dihedral angles of shape {angles.shape} cannot provide "
                f"{number_frames} frames x {num_dihedrals} dihedrals"
            )
        angles = angles[:number_frames, :num_dihedrals]
        return np.where(angles < 0, angles + 360.0, angles)

    def _process_dihedral_phi(
        self,
        dihedral_results,
        num_dihedrals,
        number_frames,
        phi_values,
    ):
        """Append the angles of one Dihedral run to the pooled angle series.

        Args:
            dihedral_results: The result of an MDAnalysis ``Dihedral.run()``;
                ``results.angles`` is indexed as [frame][dihedral].
            num_dihedrals: The number of dihedrals in the molecule or residue.
            number_frames: The number of frames to read.
            phi_values: Dict mapping dihedral_index -> list of angles collected
                so far. It is updated in place.

        Returns:
            ``phi_values``, with the new angles (shifted into 0-360 degrees)
            appended to each dihedral's list.
        """
        wrapped = self._wrapped_angles(dihedral_results, num_dihedrals, number_frames)
        for dihedral_index in range(num_dihedrals):
            phi_values.setdefault(dihedral_index, []).extend(
                wrapped[:, dihedral_index].tolist()
            )
        return phi_values

    def _process_histogram(
        self,
        num_dihedrals,
        phi_values,
        bin_width,
    ):
        """Find the histogram peaks of each dihedral's pooled angle values.

        Args:
            num_dihedrals: The number of dihedrals in the molecule or residue.
            phi_values: Mapping dihedral_index -> angle values (0-360 degrees).
            bin_width: Histogram bin width (degrees).

        Returns:
            List with one entry per dihedral, each a list of peak positions
            (bin centres).
        """
        if num_dihedrals == 0:
            return []

        number_bins = int(360 / bin_width)
        peak_values = []
        for dihedral_index in range(num_dihedrals):
            popul, bin_edges = np.histogram(
                a=phi_values[dihedral_index], bins=number_bins, range=(0, 360)
            )
            logger.debug("Histogram: %s", popul)

            bin_value = (0.5 * (bin_edges[:-1] + bin_edges[1:])).tolist()
            peaks = self._find_histogram_peaks(popul=popul, bin_value=bin_value)
            peak_values.append(peaks)

            logger.debug("Dihedral: %s Peaks: %s", dihedral_index, peaks)

        return peak_values

    @staticmethod
    def _find_histogram_peaks(
        popul: np.ndarray[Any, Any], bin_value: list[float]
    ) -> list[float]:
        """Return convex turning-point peaks from a histogram.

        A bin is a peak when it is non-empty, at least as populated as its
        left neighbour and strictly more populated than its right neighbour.
        Both neighbours wrap around because dihedral angles are circular.

        Args:
            popul: The array of counts for each bin.
            bin_value: The dihedral angle value at the center of each bin.

        Returns:
            List of bin-centre values associated with peaks.
        """
        counts = np.asarray(popul)
        if counts.size == 0:
            return []

        left = np.roll(counts, 1)  # left[i] == counts[i - 1], wrapping
        right = np.roll(counts, -1)  # right[i] == counts[i + 1], wrapping
        is_peak = (counts != 0) & (counts >= left) & (counts > right)
        return [bin_value[index] for index in np.flatnonzero(is_peak)]

    def _assign_states(
        self,
        data_container: Any,
        group_id: int,
        molecules: list[Any],
        level_list: list[Any],
        peaks_ua: list[list[Any]],
        peaks_res: list[Any],
        states_ua: Any,
        states_res: Any,
        flexible_ua: Any,
        flexible_res: Any,
    ) -> None:
        """Assign discrete state labels for one molecule group.

        Results are written into the ``states_*`` / ``flexible_*`` containers
        passed in; nothing is returned.

        Important:
            This function intentionally preserves the legacy behavior: it
            samples over the full trajectory length for each molecule and does
            not apply start/end/step to the Dihedral run.

        Args:
            data_container: MDAnalysis universe.
            group_id: Id of the group being processed.
            molecules: Molecule ids in the group (must be non-empty).
            level_list: Enabled level names for this group.
            peaks_ua: Peaks per residue and dihedral from ``_identify_peaks``.
            peaks_res: Peaks per residue-level dihedral.
            states_ua: Dict updated in place, (group_id, res_id) -> labels.
            states_res: List; this group's residue-level labels are appended.
            flexible_ua: Dict updated in place, (group_id, res_id) -> count.
            flexible_res: List; this group's flexible count is appended.
        """
        rep_mol = self._universe_operations.extract_fragment(
            data_container, molecules[0]
        )
        number_frames = len(rep_mol.trajectory)
        num_residues = len(rep_mol.residues)

        state_res: list[str] = []
        flex_res = 0

        for molecule in molecules:
            mol = self._universe_operations.extract_fragment(data_container, molecule)
            for level in level_list:
                if level == "united_atom":
                    for res_id in range(num_residues):
                        key = (group_id, res_id)
                        heavy_res = self._select_heavy_residue(mol, res_id)
                        dihedrals = self._get_dihedrals(heavy_res, level)
                        if not dihedrals:
                            # No dihedrals, no conformations
                            states_ua[key] = []
                            flexible_ua[key] = 0
                            continue

                        dihedral_results = Dihedral(dihedrals).run()
                        states, flexible = self._process_conformations(
                            peaks_ua[res_id],
                            dihedral_results,
                            len(dihedrals),
                            number_frames,
                        )
                        if key in states_ua:
                            states_ua[key].extend(states)
                            flexible_ua[key] = max(flexible_ua[key], flexible)
                        else:
                            states_ua[key] = states
                            flexible_ua[key] = flexible

                elif level == "residue":
                    dihedrals = self._get_dihedrals(mol, level)
                    if not dihedrals:
                        # No dihedrals, no conformations
                        state_res = []
                        continue

                    dihedral_results = Dihedral(dihedrals).run()
                    states, flexible = self._process_conformations(
                        peaks_res,
                        dihedral_results,
                        len(dihedrals),
                        number_frames,
                    )
                    state_res.extend(states)
                    flex_res = max(flex_res, flexible)

        # Appended even when "residue" is not an enabled level, so both lists
        # gain exactly one entry per processed group.
        states_res.append(state_res)
        flexible_res.append(flex_res)

    def _process_conformations(
        self, peaks, dihedral_results, num_dihedrals, number_frames
    ):
        """Label every frame with the nearest histogram peak of each dihedral.

        Args:
            peaks: ``peaks[dihedral_index]`` -> list of peak positions
                (degrees, 0-360).
            dihedral_results: The result of an MDAnalysis ``Dihedral.run()``.
            num_dihedrals: Number of dihedral angles in the molecule or residue.
            number_frames: Number of frames to label.

        Returns:
            tuple: (states, num_flexible)

            - states: One string per frame, the concatenated peak indices of
              all dihedrals. Empty when there are no dihedrals or no frames.
            - num_flexible: Number of dihedrals that visit more than one peak.
        """
        if num_dihedrals == 0 or number_frames == 0:
            return [], 0

        # Angles shifted into 0-360 so they share the peaks' coordinate range.
        angles = self._wrapped_angles(dihedral_results, num_dihedrals, number_frames)

        state_indices = np.zeros((number_frames, num_dihedrals), dtype=int)
        num_flexible = 0
        for dihedral_index in range(num_dihedrals):
            peak_positions = np.asarray(peaks[dihedral_index], dtype=float)
            # A dihedral without peaks (e.g. a perfectly flat histogram) is
            # treated as a single state 0 instead of crashing in argmin.
            if peak_positions.size:
                gap = np.abs(
                    angles[:, dihedral_index, np.newaxis]
                    - peak_positions[np.newaxis, :]
                )
                # Dihedrals are periodic: 359 degrees is 3.5 degrees away from
                # a peak at 2.5, not 356.5. Use the shorter way round.
                gap = np.minimum(gap, 360.0 - gap)
                state_indices[:, dihedral_index] = np.argmin(gap, axis=1)

            if np.unique(state_indices[:, dihedral_index]).size > 1:
                num_flexible += 1

        # Concatenate all the dihedrals in the molecule into the state for the
        # frame.
        # NOTE(review): labels become ambiguous once a dihedral has ten or
        # more peaks ("110" could be 1,10 or 11,0); kept for compatibility.
        states = ["".join(map(str, frame)) for frame in state_indices.tolist()]

        return states, num_flexible