# Source code for CodeEntropy.levels.dihedrals
"""Dihedral state assignment for conformational entropy.

This module converts selected-frame dihedral angle time series into discrete
conformational state labels. The resulting state labels are used downstream to
compute configurational entropy.

Frame-index contract:

- ``FrameSelection.analysis_indices`` are used for MDAnalysis trajectory access
  in the active analysis universe.
- ``Dihedral(...).run(start, stop, step)`` uses frame bounds in the active
  analysis-universe index space.
- ``dihedral_results.results.angles`` is always indexed locally from zero.
  Never use an absolute/source frame index directly into that result array.
"""
from __future__ import annotations
import logging
from typing import Any
import numpy as np
from MDAnalysis.analysis.dihedrals import Dihedral
from rich.progress import TaskID
from CodeEntropy.results.reporter import _RichProgressSink
from CodeEntropy.trajectory.frames import FrameSelection
logger = logging.getLogger(__name__)
UAKey = tuple[int, int]
# [docs]
class ConformationStateBuilder:
    """Build conformational state labels from selected-frame dihedral angles.

    The workflow for every molecule group is:

    1. ``_identify_peaks`` collects dihedral angles (mapped onto ``[0, 360)``)
       for all molecules of the group and locates the histogram peaks of every
       dihedral.
    2. ``_assign_states`` labels each frame of each molecule with the index of
       the nearest peak, per dihedral, and concatenates those indices into one
       state string per frame.
    """

    def __init__(self, universe_operations: Any) -> None:
        """Store the helper used for fragment extraction and atom selection.

        Args:
            universe_operations: Object providing helper methods:
                - extract_fragment(data_container, molecule_id)
                - select_atoms(atomgroup, selection_string)
        """
        self._universe_operations = universe_operations

    def build_conformational_states(
        self,
        data_container: Any,
        levels: dict[Any, list[str]],
        groups: dict[int, list[Any]],
        bin_width: float,
        frame_selection: FrameSelection,
        progress: _RichProgressSink | None = None,
    ) -> tuple[dict[UAKey, list[str]], list[list[str]], dict[UAKey, int], list[int]]:
        """Build conformational state labels from selected trajectory frames.

        Args:
            data_container: MDAnalysis Universe or compatible container used to
                extract fragments and compute dihedral time series.
            levels: Mapping of molecule id to enabled level names. Only the
                entry of the first molecule of each group is consulted.
            groups: Mapping of group id to molecule ids.
            bin_width: Histogram bin width in degrees used when identifying peak
                dihedral populations.
            frame_selection: FrameSelection controlling which frames are analysed.
                During the current migration stage, ``analysis_indices`` are local
                indices into the physically frame-sliced analysis universe.
            progress: Optional progress sink. When given, one task is created and
                advanced once per group (or once in total when ``groups`` is
                empty).

        Returns:
            Tuple ``(states_ua, states_res, flexible_ua, flexible_res)``.
        """
        number_groups = len(groups)
        states_ua: dict[UAKey, list[str]] = {}
        # NOTE(review): ``states_res`` starts with one empty list per group and
        # ``_assign_states`` then *appends* one more list per non-empty group,
        # whereas ``flexible_res`` starts empty. The two lists therefore do not
        # share the same indexing. Behaviour is kept as-is; confirm the layout
        # expected by the downstream consumer.
        states_res: list[list[str]] = [[] for _ in range(number_groups)]
        flexible_ua: dict[UAKey, int] = {}
        flexible_res: list[int] = []

        task: TaskID | None = None
        if progress is not None:
            task = progress.add_task(
                "[green]Conformational states",
                total=max(1, number_groups),
                title="Initializing",
            )

        def report(title: str | None = None, *, advance: bool = False) -> None:
            """Update the progress task title and/or advance it, if present."""
            if progress is None or task is None:
                return
            if title is not None:
                progress.update(task, title=title)
            if advance:
                progress.advance(task)

        if not groups:
            report("No groups", advance=True)
            return states_ua, states_res, flexible_ua, flexible_res

        for group_id, molecules in groups.items():
            if not molecules:
                report(f"Group {group_id} (empty)", advance=True)
                continue

            report(f"Group {group_id}")

            # The first molecule of the group decides which levels are analysed.
            level_list = levels[molecules[0]]

            peaks_ua, peaks_res = self._identify_peaks(
                data_container=data_container,
                molecules=molecules,
                bin_width=bin_width,
                level_list=level_list,
                frame_selection=frame_selection,
            )
            self._assign_states(
                data_container=data_container,
                group_id=group_id,
                molecules=molecules,
                level_list=level_list,
                peaks_ua=peaks_ua,
                peaks_res=peaks_res,
                states_ua=states_ua,
                states_res=states_res,
                flexible_ua=flexible_ua,
                flexible_res=flexible_res,
                frame_selection=frame_selection,
            )
            report(advance=True)

        logger.debug("States UA: %s", states_ua)
        logger.debug("Number of flexible dihedrals UA: %s", flexible_ua)
        logger.debug("States Res: %s", states_res)
        logger.debug("Number of flexible dihedrals Res: %s", flexible_res)
        return states_ua, states_res, flexible_ua, flexible_res

    def _select_heavy_residue(self, mol: Any, res_id: int) -> Any:
        """Select heavy atoms in a residue by residue index.

        The residue is addressed through the atom-index range spanned by its
        first and last atom, then filtered with ``prop mass > 1.1``.

        Args:
            mol: Representative molecule AtomGroup.
            res_id: Local residue index.

        Returns:
            AtomGroup containing heavy atoms in the residue selection.
        """
        atom_indices = mol.residues[res_id].atoms.indices
        first_index = atom_indices[0]
        last_index = atom_indices[-1]
        res_container = self._universe_operations.select_atoms(
            mol, f"index {first_index}:{last_index}"
        )
        return self._universe_operations.select_atoms(res_container, "prop mass > 1.1")

    def _get_dihedrals(self, data_container: Any, level: str) -> list[Any]:
        """Return dihedral AtomGroups for a container at a given level.

        For ``"united_atom"`` the atoms of every entry in
        ``data_container.dihedrals`` are returned. For ``"residue"`` one group is
        built per window of four consecutive residues from the atoms selected by
        the ``bonded`` keyword between residues 1-2 and residues 3-4 of the
        window. Any other level yields an empty list.

        Args:
            data_container: MDAnalysis container.
            level: Either ``"united_atom"`` or ``"residue"``.

        Returns:
            List of AtomGroups, each representing a dihedral definition.
        """
        atom_groups: list[Any] = []

        if level == "united_atom":
            atom_groups.extend(dihedral.atoms for dihedral in data_container.dihedrals)
        elif level == "residue":
            num_residues = len(data_container.residues)
            # One dihedral per window of four consecutive residues; the range is
            # empty when fewer than four residues are present.
            for first in range(num_residues - 3):
                second, third, fourth = first + 1, first + 2, first + 3
                atom1 = data_container.select_atoms(
                    f"resindex {first} and bonded resindex {second}"
                )
                atom2 = data_container.select_atoms(
                    f"resindex {second} and bonded resindex {first}"
                )
                atom3 = data_container.select_atoms(
                    f"resindex {third} and bonded resindex {fourth}"
                )
                atom4 = data_container.select_atoms(
                    f"resindex {fourth} and bonded resindex {third}"
                )
                atom_groups.append(atom1 + atom2 + atom3 + atom4)

        logger.debug("Level: %s, Dihedrals: %s", level, atom_groups)
        return atom_groups

    def _identify_peaks(
        self,
        data_container: Any,
        molecules: list[Any],
        bin_width: float,
        level_list: list[Any],
        frame_selection: FrameSelection,
    ) -> tuple[list[list[Any]], list[Any]]:
        """Identify histogram peaks for each selected-frame dihedral series.

        Angles of all molecules in the group are pooled per dihedral before the
        histogram is built. The residue count is taken from the first molecule.

        Args:
            data_container: MDAnalysis universe.
            molecules: Molecule ids in the group.
            bin_width: Histogram bin width in degrees.
            level_list: Enabled hierarchy levels for the representative molecule.
            frame_selection: Selected frames in the active analysis-universe index
                space.

        Returns:
            Tuple of ``(peaks_ua, peaks_res)``. ``peaks_ua`` holds one entry per
            residue; each non-empty entry is a list of peak lists, one per
            dihedral. ``peaks_res`` is a list of peak lists, one per
            residue-level dihedral.
        """
        rep_mol = self._universe_operations.extract_fragment(
            data_container, molecules[0]
        )
        number_frames = self._analysis_frame_count(frame_selection)
        num_residues = len(rep_mol.residues)

        num_dihedrals_ua: list[int] = [0] * num_residues
        num_dihedrals_res = 0
        # Accumulators: dihedral index -> pooled angle values.
        phi_ua: dict[int, dict[int, list[float]]] = {}
        phi_res: dict[int, list[float]] = {}

        for molecule in molecules:
            mol = self._universe_operations.extract_fragment(data_container, molecule)
            for level in level_list:
                if level == "united_atom":
                    for res_id in range(num_residues):
                        heavy_res = self._select_heavy_residue(mol, res_id)
                        dihedrals = self._get_dihedrals(heavy_res, level)
                        num_dihedrals_ua[res_id] = len(dihedrals)
                        if not dihedrals:
                            # No dihedrals here: discard anything accumulated.
                            phi_ua[res_id] = {}
                            continue
                        dihedral_results = self._run_dihedrals(
                            dihedrals=dihedrals,
                            frame_selection=frame_selection,
                        )
                        phi_ua[res_id] = self._process_dihedral_phi(
                            dihedral_results=dihedral_results,
                            num_dihedrals=num_dihedrals_ua[res_id],
                            number_frames=number_frames,
                            phi_values=phi_ua.setdefault(res_id, {}),
                        )
                elif level == "residue":
                    dihedrals = self._get_dihedrals(mol, level)
                    num_dihedrals_res = len(dihedrals)
                    if not dihedrals:
                        # No dihedrals here: discard anything accumulated.
                        phi_res = {}
                        continue
                    dihedral_results = self._run_dihedrals(
                        dihedrals=dihedrals,
                        frame_selection=frame_selection,
                    )
                    phi_res = self._process_dihedral_phi(
                        dihedral_results=dihedral_results,
                        num_dihedrals=num_dihedrals_res,
                        number_frames=number_frames,
                        phi_values=phi_res,
                    )

        logger.debug("phi_ua %s", phi_ua)
        logger.debug("phi_res %s", phi_res)

        peaks_ua: list[list[Any]] = [[] for _ in range(num_residues)]
        peaks_res: list[Any] = []
        for level in level_list:
            if level == "united_atom":
                for res_id in range(num_residues):
                    phi_values = phi_ua.get(res_id)
                    if phi_values:
                        peaks_ua[res_id] = self._process_histogram(
                            num_dihedrals=num_dihedrals_ua[res_id],
                            phi_values=phi_values,
                            bin_width=bin_width,
                        )
                    else:
                        peaks_ua[res_id] = []
            elif level == "residue":
                if phi_res:
                    peaks_res = self._process_histogram(
                        num_dihedrals=num_dihedrals_res,
                        phi_values=phi_res,
                        bin_width=bin_width,
                    )
                else:
                    peaks_res = []

        return peaks_ua, peaks_res

    def _process_dihedral_phi(
        self,
        dihedral_results: Any,
        num_dihedrals: int,
        number_frames: int,
        phi_values: dict[int, list[float]],
    ) -> dict[int, list[float]]:
        """Collect positive-angle dihedral values from a local result array.

        Negative angles are shifted by 360 so that all collected values are
        non-negative.

        Args:
            dihedral_results: Result of ``MDAnalysis.analysis.dihedrals.Dihedral``.
            num_dihedrals: Number of dihedrals in the result.
            number_frames: Number of local frames in ``dihedral_results``.
            phi_values: Existing accumulator mapping dihedral index to values.
                It is updated in place.

        Returns:
            Updated ``phi_values`` accumulator (the same object that was passed).

        Notes:
            ``dihedral_results.results.angles`` is indexed locally from zero.
        """
        angles = dihedral_results.results.angles
        for dihedral_index in range(num_dihedrals):
            phi: list[float] = []
            for local_i in range(number_frames):
                value = angles[local_i][dihedral_index]
                if value < 0:
                    value += 360
                phi.append(float(value))
            phi_values.setdefault(dihedral_index, []).extend(phi)
        return phi_values

    def _process_histogram(
        self,
        num_dihedrals: int,
        phi_values: dict[int, list[float]],
        bin_width: float,
    ) -> list[Any]:
        """Find histogram peaks from dihedral angle values.

        Args:
            num_dihedrals: Number of dihedrals.
            phi_values: Mapping from dihedral index to angle values.
            bin_width: Histogram bin width in degrees.

        Returns:
            List of peak lists, one per dihedral.

        Raises:
            ValueError: If ``bin_width`` does not yield at least one bin over the
                ``0``-``360`` degree range.
        """
        if num_dihedrals <= 0:
            return []

        # The bin layout is identical for every dihedral, so derive it once.
        if not 0 < bin_width <= 360:
            raise ValueError(
                f"bin_width must be in the interval (0, 360] degrees, got {bin_width}."
            )
        number_bins = int(360 / bin_width)

        peak_values: list[Any] = []
        for dihedral_index in range(num_dihedrals):
            phi = phi_values[dihedral_index]
            popul, bin_edges = np.histogram(a=phi, bins=number_bins, range=(0, 360))
            logger.debug("Histogram: %s", popul)
            # Bin centres: midpoint of each pair of neighbouring edges.
            bin_value = (0.5 * (bin_edges[:-1] + bin_edges[1:])).tolist()
            peaks = self._find_histogram_peaks(popul=popul, bin_value=bin_value)
            peak_values.append(peaks)
            logger.debug("Dihedral: %s Peaks: %s", dihedral_index, peaks)
        return peak_values

    @staticmethod
    def _find_histogram_peaks(
        popul: np.ndarray[Any, Any], bin_value: list[float]
    ) -> list[float]:
        """Return convex turning-point peaks from a histogram.

        The histogram is treated as circular: the left neighbour of the first
        bin is the last bin and the right neighbour of the last bin is the first
        bin. A populated bin is a peak when it is at least as high as its left
        neighbour and strictly higher than its right neighbour, so a flat-topped
        plateau contributes exactly one peak (its right-most bin).

        Args:
            popul: Histogram bin populations.
            bin_value: Histogram bin centre values.

        Returns:
            List of peak positions.
        """
        number_bins = len(popul)
        peaks: list[float] = []
        for bin_index in range(number_bins):
            count = popul[bin_index]
            if count == 0:
                continue
            left = popul[(bin_index - 1) % number_bins]
            right = popul[(bin_index + 1) % number_bins]
            if count >= left and count > right:
                peaks.append(bin_value[bin_index])
        return peaks

    def _assign_states(
        self,
        data_container: Any,
        group_id: int,
        molecules: list[Any],
        level_list: list[Any],
        peaks_ua: list[list[Any]],
        peaks_res: list[Any],
        states_ua: Any,
        states_res: Any,
        flexible_ua: Any,
        flexible_res: Any,
        frame_selection: FrameSelection,
    ) -> None:
        """Assign discrete state labels for selected-frame dihedrals.

        United-atom results are stored under ``(group_id, res_id)`` keys and are
        extended molecule by molecule. Residue-level results of all molecules are
        pooled into one list that is appended to ``states_res`` (and the maximum
        flexible-dihedral count to ``flexible_res``) once per call.

        Args:
            data_container: MDAnalysis universe.
            group_id: Molecule group id.
            molecules: Molecule ids in the group.
            level_list: Enabled hierarchy levels.
            peaks_ua: UA-level peaks by residue.
            peaks_res: Residue-level peaks.
            states_ua: UA state accumulator.
            states_res: Residue state accumulator.
            flexible_ua: UA flexible-dihedral accumulator.
            flexible_res: Residue flexible-dihedral accumulator.
            frame_selection: Selected frames in the active analysis-universe index
                space.

        Returns:
            None. Mutates the provided state/flexible accumulators.
        """
        rep_mol = self._universe_operations.extract_fragment(
            data_container, molecules[0]
        )
        number_frames = self._analysis_frame_count(frame_selection)
        num_residues = len(rep_mol.residues)

        state_res: list[str] = []
        flex_res = 0

        for molecule in molecules:
            mol = self._universe_operations.extract_fragment(data_container, molecule)
            for level in level_list:
                if level == "united_atom":
                    for res_id in range(num_residues):
                        key = (group_id, res_id)
                        heavy_res = self._select_heavy_residue(mol, res_id)
                        dihedrals = self._get_dihedrals(heavy_res, level)
                        if not dihedrals:
                            states_ua[key] = []
                            flexible_ua[key] = 0
                            continue
                        dihedral_results = self._run_dihedrals(
                            dihedrals=dihedrals,
                            frame_selection=frame_selection,
                        )
                        states, flexible = self._process_conformations(
                            peaks=peaks_ua[res_id],
                            dihedral_results=dihedral_results,
                            num_dihedrals=len(dihedrals),
                            number_frames=number_frames,
                        )
                        if key in states_ua:
                            states_ua[key].extend(states)
                            flexible_ua[key] = max(flexible_ua[key], flexible)
                        else:
                            states_ua[key] = states
                            flexible_ua[key] = flexible
                elif level == "residue":
                    dihedrals = self._get_dihedrals(mol, level)
                    if not dihedrals:
                        # No dihedrals here: discard anything accumulated.
                        state_res = []
                        continue
                    dihedral_results = self._run_dihedrals(
                        dihedrals=dihedrals,
                        frame_selection=frame_selection,
                    )
                    states, flexible = self._process_conformations(
                        peaks=peaks_res,
                        dihedral_results=dihedral_results,
                        num_dihedrals=len(dihedrals),
                        number_frames=number_frames,
                    )
                    state_res.extend(states)
                    flex_res = max(flex_res, flexible)

        states_res.append(state_res)
        flexible_res.append(flex_res)

    def _process_conformations(
        self,
        peaks: list[Any],
        dihedral_results: Any,
        num_dihedrals: int,
        number_frames: int,
    ) -> tuple[list[str], int]:
        """Assign conformational state labels from local dihedral results.

        Each angle is mapped onto ``[0, 360)`` and labelled with the index of
        the nearest peak of its dihedral. Distances are measured on the circle
        (355 degrees is 10 degrees away from 5 degrees), which matches the
        circular neighbour handling used when the peaks are located. The
        per-dihedral labels of one frame are concatenated into that frame's
        state string.

        Args:
            peaks: Histogram peaks, one list of peak positions per dihedral.
            dihedral_results: Result of ``Dihedral(...).run(...)``.
            num_dihedrals: Number of dihedrals.
            number_frames: Number of local result frames.

        Returns:
            Tuple of ``(states, num_flexible)`` where ``states`` holds one label
            string per frame and ``num_flexible`` counts the dihedrals that visit
            more than one state.

        Raises:
            ValueError: If a dihedral has no peaks although frames must be
                labelled.

        Notes:
            ``dihedral_results.results.angles`` is indexed locally from zero.
        """
        conformations: list[list[int]] = []
        num_flexible = 0

        for dihedral_index in range(num_dihedrals):
            conformation: list[int] = []
            if number_frames > 0:
                dihedral_peaks = peaks[dihedral_index]
                if len(dihedral_peaks) == 0:
                    raise ValueError(
                        f"No histogram peaks available for dihedral {dihedral_index}; "
                        "cannot assign conformational states."
                    )
                for local_i in range(number_frames):
                    value = dihedral_results.results.angles[local_i][dihedral_index]
                    if value < 0:
                        value += 360
                    distances = []
                    for peak in dihedral_peaks:
                        delta = abs(value - peak) % 360
                        # Shortest way round the circle.
                        distances.append(min(delta, 360 - delta))
                    conformation.append(int(np.argmin(distances)))

            if len(set(conformation)) > 1:
                num_flexible += 1
            conformations.append(conformation)

        # One label per frame; empty labels (no dihedrals) are dropped.
        states: list[str] = []
        for frame in range(number_frames):
            label = "".join(
                str(conformations[dihedral][frame]) for dihedral in range(num_dihedrals)
            )
            if label:
                states.append(label)
        return states, num_flexible

    def _run_dihedrals(
        self, dihedrals: list[Any], frame_selection: FrameSelection
    ) -> Any:
        """Run MDAnalysis dihedral analysis over the selected frames.

        Args:
            dihedrals: Dihedral AtomGroups.
            frame_selection: Frame selection providing the run bounds.

        Returns:
            MDAnalysis Dihedral analysis result.

        Raises:
            ValueError: If ``dihedrals`` is empty or the frame selection is empty.

        Notes:
            The ``start``/``stop``/``step`` passed to ``Dihedral.run`` come from
            ``_analysis_run_bounds``. The returned ``results.angles`` array is
            indexed locally from zero.
        """
        if not dihedrals:
            raise ValueError("Cannot run Dihedral analysis with no dihedrals.")
        start, stop, step = self._analysis_run_bounds(frame_selection)
        return Dihedral(dihedrals).run(start=start, stop=stop, step=step)

    @staticmethod
    def _analysis_frame_count(frame_selection: FrameSelection) -> int:
        """Return the number of selected frames (``frame_selection.n_frames``)."""
        return frame_selection.n_frames

    @staticmethod
    def _analysis_run_bounds(frame_selection: FrameSelection) -> tuple[int, int, int]:
        """Return MDAnalysis run bounds for the selected frames.

        Args:
            frame_selection: Frame selection exposing ``source_start``,
                ``source_stop_exclusive`` and ``infer_source_step()``.

        Returns:
            Tuple of ``(start, stop, step)`` taken from the selection's
            ``source_*`` attributes.

        Raises:
            ValueError: If the selection is empty.

        Notes:
            NOTE(review): the module docstring states that ``Dihedral.run``
            bounds live in the active analysis-universe index space, while the
            values here are read from ``source_*`` attributes. Confirm that both
            index spaces coincide for the universe being analysed.
        """
        start = frame_selection.source_start
        stop = frame_selection.source_stop_exclusive
        if start is None or stop is None:
            raise ValueError("Frame selection is empty.")
        # The step is only inferred once the selection is known to be non-empty.
        return start, stop, frame_selection.infer_source_step()