Source code for pinder.core.structure.surgery
from __future__ import annotations
import biotite.structure as struc
import numpy as np
from numpy.typing import NDArray
from biotite.structure.atoms import AtomArray, AtomArrayStack
from pinder.core.utils import setup_logger
from pinder.core.structure.atoms import apply_mask, get_seq_aligned_structures
from pinder.core.structure.models import ChainConfig
log = setup_logger(__name__)
[docs]
def remove_annotations(
structure: AtomArray | AtomArrayStack,
categories: list[str] = ["element", "ins_code"],
) -> AtomArray | AtomArrayStack:
if isinstance(structure, AtomArrayStack):
shape_idx = 1
else:
shape_idx = 0
for annotation in categories:
val: float | str = 0.0 if annotation == "b_factor" else ""
annotation_arr: NDArray[np.double | np.str_] = np.repeat(
val, structure.shape[shape_idx]
)
structure.set_annotation(annotation, annotation_arr)
return structure
[docs]
def fix_annotation_mismatch(
ref: AtomArray,
decoys: AtomArrayStack,
categories: list[str] = ["element", "ins_code", "b_factor"],
) -> tuple[AtomArray, AtomArrayStack]:
for annot in ref.get_annotation_categories():
ref_annot = ref.get_annotation(annot)
decoy_annot = decoys.get_annotation(annot)
if not np.array_equal(ref_annot, decoy_annot):
log.debug(f"Decoy and ref have differing {annot} categories!")
if annot not in categories:
continue
decoys = remove_annotations(decoys, categories=[annot])
ref = remove_annotations(ref, categories=[annot])
return ref, decoys
[docs]
def fix_mismatched_atoms(
native: AtomArray, decoy_stack: AtomArrayStack, max_atom_delta: int
) -> tuple[AtomArray, AtomArrayStack]:
identical = np.array_equal(decoy_stack.res_id, native.res_id)
if identical:
# Both shape and res_id elements are identical
return native, decoy_stack
log.debug("Detected mismatch between native and decoy stack!")
# Sequence based structural alignment
# Make the decoy_stack match numbering of native
log.debug("Attempting sequence-based structural alignment")
native, decoy_stack = get_seq_aligned_structures(native, decoy_stack)
native_intersect = native[struc.filter_intersection(native, decoy_stack)]
in_common = native_intersect.shape[0]
native_atoms = native.shape[0]
mismatch = native_atoms - in_common
if mismatch > max_atom_delta:
log.debug(
f"Large atom mismatch detected between native and models: {mismatch} atoms."
)
log.debug("Attempting to fix mismatch")
# In test case there are missing element annotations
native = remove_annotations(native)
decoy_stack = remove_annotations(decoy_stack)
native_intersect = native[struc.filter_intersection(native, decoy_stack)]
in_common = native_intersect.shape[0]
log.warning(
"Caution: results will only represent the atoms in common! "
f"keeping {in_common} / {native_atoms} atoms in common"
)
native = native_intersect.copy()
decoy_stack = decoy_stack[..., struc.filter_intersection(decoy_stack, native)]
return native, decoy_stack
[docs]
def set_canonical_chain_order(
structure: AtomArray | AtomArrayStack | list[AtomArray],
chains: ChainConfig,
subject: str,
) -> AtomArray | AtomArrayStack | list[AtomArray]:
# Create set of residues in interface split into receptor and ligand
# Conflict between residue numbers in different chains is handled by
# logical mask on array.chain_id and array.res_id
lig_chains = getattr(chains, f"{subject}_ligand")
rec_chains = getattr(chains, f"{subject}_receptor")
if isinstance(structure, list):
for i, arr in enumerate(structure):
R_mask = np.isin(arr.chain_id, rec_chains)
L_mask = np.isin(arr.chain_id, lig_chains)
R = arr[R_mask].copy()
L = arr[L_mask].copy()
structure[i] = R + L
return structure
else:
R_mask = np.isin(structure.chain_id, rec_chains)
L_mask = np.isin(structure.chain_id, lig_chains)
R = apply_mask(structure, R_mask)
L = apply_mask(structure, L_mask)
return R + L
[docs]
def remove_duplicate_calpha(atoms: AtomArray) -> AtomArray:
unique_mask = []
unique = set()
for at in atoms:
at_id = f"{at.chain_id}-{at.res_id}"
mask = not (at_id in unique)
if mask:
unique.add(at_id)
else:
log.warning(f"{at_id} is duplicated!")
unique_mask.append(mask)
atoms = atoms[np.array(unique_mask)].copy()
return atoms