Source code for CodeEntropy.levels.dihedrals.state_assignment
"""Conformational state assignment from dihedral peak definitions.
This module contains the logic for converting positive-angle dihedral arrays and
global peak definitions into state labels and flexible-dihedral counts.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import numpy as np
from CodeEntropy.levels.dihedrals.angle_observations import (
ConformationChunkTask,
DihedralAngleObservable,
)
from CodeEntropy.levels.dihedrals.kernels import (
assign_peak_labels_and_count_flexible,
)
from CodeEntropy.levels.dihedrals.topology import MoleculeDihedralTopology
UAKey = tuple[int, int]
[docs]
@dataclass
class ConformationStateData:
"""Serial conformational state data calculated for one molecule group.
Attributes:
state_res: Residue-level state labels for the group.
flex_res: Number of flexible residue-level dihedrals for the group.
states_ua_updates: United-atom state-label updates by ``(group, residue)``.
flexible_ua_updates: United-atom flexible-dihedral updates by
``(group, residue)``.
"""
state_res: list[str]
flex_res: int
states_ua_updates: dict[UAKey, list[str]]
flexible_ua_updates: dict[UAKey, int]
[docs]
@dataclass
class ConformationStatePartial:
"""Chunk-local conformational state labels and flexible counts.
Attributes:
task: Source molecule/frame chunk task.
state_res: Residue-level state labels for this chunk.
flex_res: Number of flexible residue-level dihedrals for this chunk.
states_ua_updates: United-atom state-label updates by ``(group, residue)``.
flexible_ua_updates: United-atom flexible-dihedral updates by
``(group, residue)``.
"""
task: ConformationChunkTask
state_res: list[str]
flex_res: int
states_ua_updates: dict[UAKey, list[str]]
flexible_ua_updates: dict[UAKey, int]
[docs]
class ConformationStateAssigner:
"""Assign conformational state labels from global dihedral peak definitions."""
def _assign_state_partial_from_observable(
self,
observable: DihedralAngleObservable,
topology: MoleculeDihedralTopology,
level_list: list[Any],
peaks_ua: list[list[Any]],
peaks_res: list[Any],
) -> ConformationStatePartial:
"""Assign chunk-local states from cached angle arrays and global peaks.
Args:
observable: Chunk-local angle observable.
topology: Static topology for the observable molecule.
level_list: Enabled hierarchy levels.
peaks_ua: Global united-atom peaks by residue.
peaks_res: Global residue-level peaks.
Returns:
Chunk-local state partial.
"""
state_res: list[str] = []
flex_res = 0
states_ua_updates: dict[UAKey, list[str]] = {}
flexible_ua_updates: dict[UAKey, int] = {}
if "united_atom" in level_list:
for res_id in range(topology.num_residues):
key = (topology.group_id, res_id)
angles = observable.ua_angles_by_residue.get(res_id)
if angles is None or angles.shape[1] == 0:
states_ua_updates[key] = []
flexible_ua_updates[key] = 0
continue
states, flexible = self._process_conformations_from_angles(
peaks=peaks_ua[res_id],
angles=angles,
)
states_ua_updates[key] = states
flexible_ua_updates[key] = flexible
if "residue" in level_list and observable.residue_angles is not None:
if observable.residue_angles.shape[1] > 0:
state_res, flex_res = self._process_conformations_from_angles(
peaks=peaks_res,
angles=observable.residue_angles,
)
return ConformationStatePartial(
task=observable.task,
state_res=state_res,
flex_res=flex_res,
states_ua_updates=states_ua_updates,
flexible_ua_updates=flexible_ua_updates,
)
def _reduce_state_partials(
self,
partials: list[ConformationStatePartial],
) -> ConformationStateData:
"""Merge chunk-local state partials into one group-level result.
Args:
partials: Chunk-local state partials for one group.
Returns:
Group-level state data using deterministic molecule/chunk ordering.
"""
ordered_partials = sorted(
partials,
key=lambda partial: (
partial.task.molecule_order,
partial.task.chunk_id,
),
)
state_res: list[str] = []
flex_res = 0
states_ua_updates: dict[UAKey, list[str]] = {}
flexible_ua_updates: dict[UAKey, int] = {}
for partial in ordered_partials:
for key, states in partial.states_ua_updates.items():
if key not in states_ua_updates:
states_ua_updates[key] = list(states)
flexible_ua_updates[key] = partial.flexible_ua_updates[key]
else:
states_ua_updates[key].extend(states)
flexible_ua_updates[key] = max(
flexible_ua_updates[key],
partial.flexible_ua_updates[key],
)
state_res.extend(partial.state_res)
flex_res = max(flex_res, partial.flex_res)
return ConformationStateData(
state_res=state_res,
flex_res=flex_res,
states_ua_updates=states_ua_updates,
flexible_ua_updates=flexible_ua_updates,
)
@staticmethod
def _merge_group_state_data(
state_data: ConformationStateData,
states_ua: dict[UAKey, list[str]],
states_res: list[list[str]],
flexible_ua: dict[UAKey, int],
flexible_res: list[int],
) -> None:
"""Merge one group's state data into final output accumulators.
Args:
state_data: Serial conformational state data for one group.
states_ua: UA state accumulator to mutate.
states_res: Residue state accumulator to mutate.
flexible_ua: UA flexible-dihedral accumulator to mutate.
flexible_res: Residue flexible-dihedral accumulator to mutate.
Returns:
None. Mutates the provided accumulators.
"""
for key, states in state_data.states_ua_updates.items():
if key not in states_ua:
states_ua[key] = states
flexible_ua[key] = state_data.flexible_ua_updates[key]
else:
states_ua[key].extend(states)
flexible_ua[key] = max(
flexible_ua[key],
state_data.flexible_ua_updates[key],
)
states_res.append(state_data.state_res)
flexible_res.append(state_data.flex_res)
def _process_conformations_from_angles(
self,
peaks: list[Any],
angles: np.ndarray,
) -> tuple[list[str], int]:
"""Assign conformational states from a positive-angle NumPy array.
Args:
peaks: Histogram peaks by dihedral.
angles: Positive-angle array with shape ``(n_frames, n_dihedrals)``.
Returns:
Tuple of ``(states, num_flexible)``.
"""
if angles.size == 0 or angles.shape[1] == 0:
return [], 0
padded_peaks, peak_counts = self._pad_peak_values(peaks)
labels, num_flexible = assign_peak_labels_and_count_flexible(
angles,
padded_peaks,
peak_counts,
)
states = self._state_strings_from_labels(labels)
return states, int(num_flexible)
@staticmethod
def _pad_peak_values(peaks: list[Any]) -> tuple[np.ndarray, np.ndarray]:
"""Convert ragged peak lists into padded arrays for kernels.
Args:
peaks: Peak values by dihedral.
Returns:
Tuple of ``(padded_peaks, peak_counts)``.
"""
if not peaks:
return (
np.zeros((0, 1), dtype=np.float64),
np.zeros(0, dtype=np.int64),
)
max_peaks = max((len(dihedral_peaks) for dihedral_peaks in peaks), default=0)
max_peaks = max(1, max_peaks)
padded = np.zeros((len(peaks), max_peaks), dtype=np.float64)
counts = np.zeros(len(peaks), dtype=np.int64)
for dihedral_index, dihedral_peaks in enumerate(peaks):
counts[dihedral_index] = len(dihedral_peaks)
for peak_index, peak in enumerate(dihedral_peaks):
padded[dihedral_index, peak_index] = float(peak)
return padded, counts
@staticmethod
def _state_strings_from_labels(labels: np.ndarray) -> list[str]:
"""Convert integer per-frame labels into legacy state strings.
Args:
labels: Integer labels with shape ``(n_frames, n_dihedrals)``.
Returns:
Legacy state strings, one per frame.
"""
states: list[str] = []
number_frames = labels.shape[0]
num_dihedrals = labels.shape[1]
for frame_index in range(number_frames):
state = "".join(
str(int(labels[frame_index, dihedral_index]))
for dihedral_index in range(num_dihedrals)
)
if state:
states.append(state)
return states