Source code for CodeEntropy.levels.dihedrals.peak_detection
"""Conformational peak detection from dihedral angle observations.
This module contains histogram and peak-identification logic for converting
chunk-local selected-frame dihedral angle observations into global
conformational peak definitions.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any, cast
import numpy as np
from CodeEntropy.levels.dihedrals.angle_observations import (
DihedralAngleCollector,
DihedralAngleObservable,
)
from CodeEntropy.levels.dihedrals.kernels import (
histogram_counts_by_dihedral,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Histogram counts keyed by dihedral index (one array of bin counts per dihedral).
HistogramValues = dict[int, np.ndarray]
# Per-residue histogram containers keyed by residue index. A residue maps to an
# empty list when no united-atom dihedral angles were observed for it.
HistogramContainer = dict[int, HistogramValues | list[Any]]
[docs]
@dataclass
class DihedralPeakData:
    """Histogram peak definitions used for conformational state assignment.

    Peak positions are histogram bin-centre values in degrees, in the range
    0-360.

    Attributes:
        peaks_ua: United-atom peak values, indexed first by residue and then by
            dihedral index. A residue with no united-atom dihedrals has an
            empty list.
        peaks_res: Residue-level peak values by dihedral index. This is empty
            when no residue-level histograms are present.
    """

    # Outer index: residue; inner index: dihedral; innermost: peak positions.
    peaks_ua: list[list[Any]]
    # Index: residue-level dihedral; value: list of peak positions.
    peaks_res: list[Any]
[docs]
@dataclass
class DihedralHistogramData:
    """Reduced histogram counts for one conformational group.

    Counts are summed over all chunk-local observables of the group.

    Attributes:
        num_residues: Number of residues in the representative molecule.
        num_dihedrals_ua: Number of united-atom dihedrals, indexed by residue.
        num_dihedrals_res: Number of residue-level dihedrals.
        hist_ua: United-atom histogram counts, keyed by residue index and then
            by dihedral index. A residue maps to an empty list when no
            united-atom angles were observed for it.
        hist_res: Residue-level histogram counts keyed by dihedral index, or an
            empty list when no residue-level histograms are present.
    """

    num_residues: int
    num_dihedrals_ua: list[int]
    num_dihedrals_res: int
    # Values are either a per-dihedral dict of counts or an empty-list placeholder.
    hist_ua: HistogramContainer
    # Empty list (not an empty dict) signals "no residue-level histograms".
    hist_res: HistogramValues | list[Any]
[docs]
class ConformationPeakDetector(DihedralAngleCollector):
    """Identify conformational peak definitions from dihedral observations.

    Chunk-local angle observables are first reduced into summed histogram
    counts covering 0-360 degrees. Peaks are then located in each summed
    histogram. United-atom histograms are handled per residue; residue-level
    histograms are handled per molecule.
    """

    def _reduce_angle_observables_to_peak_data(
        self,
        observables: list[DihedralAngleObservable],
        level_list: list[Any],
        bin_width: float,
    ) -> DihedralPeakData:
        """Reduce chunk-local angle observables into global peak definitions.

        Args:
            observables: Chunk-local angle observables for one group.
            level_list: Enabled hierarchy levels.
            bin_width: Histogram bin width in degrees.

        Returns:
            Global peak definitions for the group.

        Raises:
            ValueError: If ``bin_width`` does not yield at least one bin.
        """
        histogram_data = self._reduce_angle_observables_to_histograms(
            observables=observables,
            level_list=level_list,
            bin_width=bin_width,
        )
        return self._build_peak_data_from_histograms(
            histogram_data=histogram_data,
            level_list=level_list,
            bin_width=bin_width,
        )

    def _reduce_angle_observables_to_histograms(
        self,
        observables: list[DihedralAngleObservable],
        level_list: list[Any],
        bin_width: float,
    ) -> DihedralHistogramData:
        """Reduce chunk-local angle arrays into summed histogram counts.

        Observables are processed in (molecule order, chunk id) order so the
        reduction is deterministic. The residue count is taken from the first
        observable in that order.

        Args:
            observables: Chunk-local angle observables for one group.
            level_list: Enabled hierarchy levels.
            bin_width: Histogram bin width in degrees.

        Returns:
            Reduced histogram counts for the group. An empty input yields an
            empty result without validating ``bin_width``.

        Raises:
            ValueError: If ``observables`` is non-empty and ``bin_width`` does
                not yield at least one bin.
        """
        if not observables:
            return DihedralHistogramData(
                num_residues=0,
                num_dihedrals_ua=[],
                num_dihedrals_res=0,
                hist_ua={},
                hist_res=[],
            )

        ordered_observables = sorted(
            observables,
            key=lambda observable: (
                observable.task.molecule_order,
                observable.task.chunk_id,
            ),
        )
        number_bins = self._number_of_bins(bin_width)
        num_residues = ordered_observables[0].num_residues
        num_dihedrals_ua = [0] * num_residues
        hist_ua: HistogramContainer = {}
        hist_res: HistogramValues | list[Any] = []
        num_dihedrals_res = 0

        if "united_atom" in level_list:
            for res_id in range(num_residues):
                for observable in ordered_observables:
                    angles = observable.ua_angles_by_residue.get(res_id)
                    if angles is None or angles.shape[1] == 0:
                        # Empty-list placeholder for "no UA dihedrals"; it never
                        # replaces histograms already accumulated for this residue.
                        hist_ua.setdefault(res_id, [])
                        continue
                    num_dihedrals_ua[res_id] = self._merge_dihedral_count(
                        num_dihedrals_ua[res_id],
                        angles.shape[1],
                        f"united-atom (residue {res_id})",
                    )
                    counts = histogram_counts_by_dihedral(angles, number_bins)
                    residue_hist = hist_ua.get(res_id)
                    if not isinstance(residue_hist, dict):
                        # Either unseen or still the empty-list placeholder.
                        residue_hist = {}
                        hist_ua[res_id] = residue_hist
                    self._accumulate_counts(residue_hist, counts)

        if "residue" in level_list:
            for observable in ordered_observables:
                angles = observable.residue_angles
                if angles is None or angles.shape[1] == 0:
                    continue
                num_dihedrals_res = self._merge_dihedral_count(
                    num_dihedrals_res, angles.shape[1], "residue-level"
                )
                counts = histogram_counts_by_dihedral(angles, number_bins)
                if not isinstance(hist_res, dict):
                    hist_res = {}
                self._accumulate_counts(hist_res, counts)

        return DihedralHistogramData(
            num_residues=num_residues,
            num_dihedrals_ua=num_dihedrals_ua,
            num_dihedrals_res=num_dihedrals_res,
            hist_ua=hist_ua,
            hist_res=hist_res,
        )

    def _build_peak_data_from_histograms(
        self,
        histogram_data: DihedralHistogramData,
        level_list: list[Any],
        bin_width: float,
    ) -> DihedralPeakData:
        """Build peak definitions from reduced histogram counts.

        Args:
            histogram_data: Reduced histogram counts for one group.
            level_list: Enabled hierarchy levels.
            bin_width: Histogram bin width in degrees.

        Returns:
            Peak definitions for united-atom and residue-level states. Peak
            positions are bin-centre values in degrees.

        Raises:
            ValueError: If ``bin_width`` does not yield at least one bin.
        """
        number_bins = self._number_of_bins(bin_width)
        bin_edges = np.linspace(0.0, 360.0, number_bins + 1)
        # Bin centres; element-wise identical to averaging each pair of edges.
        bin_value = list(0.5 * (bin_edges[:-1] + bin_edges[1:]))

        peaks_ua: list[list[Any]] = [[] for _ in range(histogram_data.num_residues)]
        peaks_res: list[Any] = []

        if "united_atom" in level_list:
            for res_id in range(histogram_data.num_residues):
                residue_hist = histogram_data.hist_ua.get(res_id)
                if not residue_hist:
                    # Missing, placeholder list, or empty dict: keep no peaks.
                    continue
                residue_hist = cast(HistogramValues, residue_hist)
                peaks_ua[res_id] = [
                    self._find_histogram_peaks(
                        popul=residue_hist[dihedral_index],
                        bin_value=bin_value,
                    )
                    for dihedral_index in range(
                        histogram_data.num_dihedrals_ua[res_id]
                    )
                ]

        if "residue" in level_list and histogram_data.hist_res:
            hist_res = cast(HistogramValues, histogram_data.hist_res)
            peaks_res = [
                self._find_histogram_peaks(
                    popul=hist_res[dihedral_index],
                    bin_value=bin_value,
                )
                for dihedral_index in range(histogram_data.num_dihedrals_res)
            ]

        return DihedralPeakData(peaks_ua=peaks_ua, peaks_res=peaks_res)

    @staticmethod
    def _number_of_bins(bin_width: float) -> int:
        """Return the number of histogram bins spanning 0-360 degrees.

        Args:
            bin_width: Histogram bin width in degrees.

        Returns:
            ``int(360 / bin_width)``, which is at least 1.

        Raises:
            ValueError: If ``bin_width`` is not positive or exceeds 360 degrees.
                Both cases would otherwise produce no usable bins.
        """
        if bin_width <= 0:
            raise ValueError(f"bin_width must be positive, got {bin_width!r}")
        number_bins = int(360 / bin_width)
        if number_bins < 1:
            raise ValueError(
                f"bin_width must not exceed 360 degrees, got {bin_width!r}"
            )
        return number_bins

    @staticmethod
    def _merge_dihedral_count(previous: int, current: int, label: str) -> int:
        """Combine one chunk's dihedral count with the running count.

        Chunks of one group are expected to report the same number of
        dihedrals. On a mismatch a warning is logged and the larger count is
        kept. The summed histograms hold the union of dihedral indices, so
        this keeps a peak definition for every accumulated histogram.

        Args:
            previous: Running count (0 when nothing has been seen yet).
            current: Dihedral count reported by the current chunk.
            label: Human-readable description used in the warning.

        Returns:
            The larger of ``previous`` and ``current``.
        """
        if previous and previous != current:
            logger.warning(
                "Inconsistent %s dihedral count across chunks: %d vs %d; "
                "using %d.",
                label,
                previous,
                current,
                max(previous, current),
            )
        return max(previous, current)

    @staticmethod
    def _accumulate_counts(target: HistogramValues, counts: np.ndarray) -> None:
        """Add per-dihedral histogram rows into ``target`` in place.

        Args:
            target: Running histogram totals keyed by dihedral index.
            counts: Histogram counts; row ``i`` holds dihedral ``i``'s bins.
        """
        for dihedral_index, row in enumerate(counts):
            if dihedral_index in target:
                # Out-of-place add preserves numpy's dtype promotion.
                target[dihedral_index] = target[dihedral_index] + row
            else:
                # Copy so later additions never alias the kernel's output.
                target[dihedral_index] = row.copy()

    @staticmethod
    def _find_histogram_peaks(
        popul: np.ndarray[Any, Any], bin_value: list[float]
    ) -> list[float]:
        """Return convex turning-point peaks from a periodic histogram.

        A non-empty bin is a peak when it is >= its left neighbour and strictly
        > its right neighbour. Neighbours wrap around because the angle axis is
        periodic. The asymmetric comparison keeps only the right-most bin of a
        flat-topped maximum.

        NOTE: a histogram in which every bin has the same population (including
        the single-bin case) yields no peaks.

        Args:
            popul: Histogram bin populations.
            bin_value: Histogram bin centre values.

        Returns:
            List of peak positions taken from ``bin_value``.
        """
        number_bins = len(popul)
        peaks: list[float] = []
        for bin_index in range(number_bins):
            population = popul[bin_index]
            if population == 0:
                continue
            left = popul[bin_index - 1]  # index -1 wraps to the last bin
            right = popul[(bin_index + 1) % number_bins]
            if population >= left and population > right:
                peaks.append(bin_value[bin_index])
        return peaks