Pinder loader

The goal of this tutorial is to provide hands-on examples of how to leverage the pinder dataset in an ML workflow. Specifically, we illustrate how to use various utilities provided by pinder to write your own data pipeline.

While this tutorial does not go into detail about how to write your own model, it covers the basic groundwork needed to interface with structures in pinder and with the associated splits and metadata. You will of course want to implement your own featurization pipelines, data representations, etc., but the hope is that this tutorial clarifies how to access the data and use it in an ML framework.

Before proceeding with this tutorial section, you may find it helpful to review the existing tutorials available in pinder.

Specifically, the tutorials covering:

Accessing and loading data for training

To access the train and val splits for PINDER, please refer to the pinder documentation.

Once you have downloaded the pinder dataset, either via the pinder package or directly through gsutil, you will have all of the necessary files for training.

To get a list of those systems and their split labels, refer to the pinder index.

We will start with the most basic way to load items from the training and validation sets: PinderSystem objects.

Recap: PinderSystem and Structure classes

import torch

from pinder.core import get_index, PinderSystem

def get_system(system_id: str) -> PinderSystem:
    return PinderSystem(system_id)


index = get_index()
train = index[index.split == "train"].copy()
system = get_system(train.id.iloc[0])
system
    
2024-11-15 12:09:52,578 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
PinderSystem(
entry = IndexEntry(
    (
        'split',
        'train',
    ),
    (
        'id',
        '8phr__X4_UNDEFINED--8phr__W4_UNDEFINED',
    ),
    (
        'pdb_id',
        '8phr',
    ),
    (
        'cluster_id',
        'cluster_24559_24559',
    ),
    (
        'cluster_id_R',
        'cluster_24559',
    ),
    (
        'cluster_id_L',
        'cluster_24559',
    ),
    (
        'pinder_s',
        False,
    ),
    (
        'pinder_xl',
        False,
    ),
    (
        'pinder_af2',
        False,
    ),
    (
        'uniprot_R',
        'UNDEFINED',
    ),
    (
        'uniprot_L',
        'UNDEFINED',
    ),
    (
        'holo_R_pdb',
        '8phr__X4_UNDEFINED-R.pdb',
    ),
    (
        'holo_L_pdb',
        '8phr__W4_UNDEFINED-L.pdb',
    ),
    (
        'predicted_R_pdb',
        '',
    ),
    (
        'predicted_L_pdb',
        '',
    ),
    (
        'apo_R_pdb',
        '',
    ),
    (
        'apo_L_pdb',
        '',
    ),
    (
        'apo_R_pdbs',
        '',
    ),
    (
        'apo_L_pdbs',
        '',
    ),
    (
        'holo_R',
        True,
    ),
    (
        'holo_L',
        True,
    ),
    (
        'predicted_R',
        False,
    ),
    (
        'predicted_L',
        False,
    ),
    (
        'apo_R',
        False,
    ),
    (
        'apo_L',
        False,
    ),
    (
        'apo_R_quality',
        '',
    ),
    (
        'apo_L_quality',
        '',
    ),
    (
        'chain1_neff',
        10.78125,
    ),
    (
        'chain2_neff',
        11.1171875,
    ),
    (
        'chain_R',
        'X4',
    ),
    (
        'chain_L',
        'W4',
    ),
    (
        'contains_antibody',
        False,
    ),
    (
        'contains_antigen',
        False,
    ),
    (
        'contains_enzyme',
        False,
    ),
)
native=Structure(
    filepath=/home/runner/.local/share/pinder/2024-02/pdbs/8phr__X4_UNDEFINED--8phr__W4_UNDEFINED.pdb,
    uniprot_map=None,
    pinder_id='8phr__X4_UNDEFINED--8phr__W4_UNDEFINED',
    atom_array=<class 'biotite.structure.AtomArray'> with shape (2556,),
    pdb_engine='fastpdb',
)
holo_receptor=Structure(
    filepath=/home/runner/.local/share/pinder/2024-02/pdbs/8phr__X4_UNDEFINED-R.pdb,
    uniprot_map=/home/runner/.local/share/pinder/2024-02/mappings/8phr__X4_UNDEFINED-R.parquet,
    pinder_id='8phr__X4_UNDEFINED-R',
    atom_array=<class 'biotite.structure.AtomArray'> with shape (1358,),
    pdb_engine='fastpdb',
)
holo_ligand=Structure(
    filepath=/home/runner/.local/share/pinder/2024-02/pdbs/8phr__W4_UNDEFINED-L.pdb,
    uniprot_map=/home/runner/.local/share/pinder/2024-02/mappings/8phr__W4_UNDEFINED-L.parquet,
    pinder_id='8phr__W4_UNDEFINED-L',
    atom_array=<class 'biotite.structure.AtomArray'> with shape (1198,),
    pdb_engine='fastpdb',
)
apo_receptor=None
apo_ligand=None
pred_receptor=None
pred_ligand=None
)

Notice the printed PinderSystem object has the following properties:

  • native - the ground-truth dimer complex

  • holo_receptor - the receptor chain (monomer) from the ground-truth complex

  • holo_ligand - the ligand chain (monomer) from the ground-truth complex

  • apo_receptor - the canonical apo chain (monomer) paired to the receptor chain

  • apo_ligand - the canonical apo chain (monomer) paired to the ligand chain

  • pred_receptor - the AlphaFold2 predicted monomer paired to the receptor chain

  • pred_ligand - the AlphaFold2 predicted monomer paired to the ligand chain

These properties are pointers to Structure objects. The Structure object provides the most direct mode of access to structures and associated properties.

Note: not all systems have an apo and/or predicted structure for all chains of the ground-truth dimer complex!

As was the case in the example above, when the alternative monomers are not available, the property will have a value of None.

You can determine a priori which systems have which alternative monomer pairings by looking at the boolean index columns apo_R and apo_L (apo receptor and ligand) and predicted_R and predicted_L (predicted receptor and ligand).

For instance, we can load a different system that has both an apo receptor and an apo ligand:

apo_system = get_system(train.query('apo_R and apo_L').id.iloc[0])
receptor = apo_system.apo_receptor
ligand = apo_system.apo_ligand 

receptor, ligand
2024-11-15 12:09:53,554 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=15
(Structure(
     filepath=/home/runner/.local/share/pinder/2024-02/pdbs/3wdb__A1_P9WPC9.pdb,
     uniprot_map=/home/runner/.local/share/pinder/2024-02/mappings/3wdb__A1_P9WPC9.parquet,
     pinder_id='3wdb__A1_P9WPC9',
     atom_array=<class 'biotite.structure.AtomArray'> with shape (1144,),
     pdb_engine='fastpdb',
 ),
 Structure(
     filepath=/home/runner/.local/share/pinder/2024-02/pdbs/6ucr__A1_P9WPC9.pdb,
     uniprot_map=/home/runner/.local/share/pinder/2024-02/mappings/6ucr__A1_P9WPC9.parquet,
     pinder_id='6ucr__A1_P9WPC9',
     atom_array=<class 'biotite.structure.AtomArray'> with shape (1193,),
     pdb_engine='fastpdb',
 ))
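Because apo and predicted monomers are only available for some systems, a pipeline usually needs a rule for which monomer to use on each side. PinderLoader exposes a monomer_priority argument for this (used with "holo" later in this tutorial); the function below is a hypothetical stand-in, not pinder's implementation. It is a minimal sketch that assumes a dict of the boolean index columns (holo_R/holo_L, apo_R/apo_L, predicted_R/predicted_L) and an illustrative priority order of apo, then predicted, then holo.

```python
from typing import Mapping, Optional, Sequence

# Hypothetical priority order: prefer apo, then predicted, then holo.
DEFAULT_PRIORITY = ("apo", "predicted", "holo")


def choose_monomer(
    flags: Mapping[str, bool],
    side: str,
    priority: Sequence[str] = DEFAULT_PRIORITY,
) -> Optional[str]:
    """Return the first available monomer type for one side ("R" or "L").

    `flags` mirrors the boolean index columns holo_R/holo_L, apo_R/apo_L and
    predicted_R/predicted_L. Returns None if nothing in `priority` is available.
    """
    if side not in ("R", "L"):
        raise ValueError(f"side must be 'R' or 'L', got {side!r}")
    for kind in priority:
        if flags.get(f"{kind}_{side}", False):
            return kind
    return None


# Toy index row; the values are made up for illustration.
row = {
    "holo_R": True, "holo_L": True,
    "apo_R": True, "apo_L": False,
    "predicted_R": True, "predicted_L": True,
}
print(choose_monomer(row, "R"))  # apo
print(choose_monomer(row, "L"))  # predicted
```

With the real index, a row such as index.iloc[0] is a pandas Series, so it could be passed as index.iloc[0].to_dict().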

We can now access e.g. the sequence and the coordinates of the structures via the Structure objects:

receptor.sequence
'PLGSMFERFTDRARRVVVLAQEEARMLNHNYIGTEHILLGLIHEGEGVAAKSLESLGISLEGVRSQVEEIIGQGQQAPSGHIPFTPRAKKVLELSLREALQLGHNYIGTEHILLGLIREGEGVAAQVLVKLGAELTRVRQQVIQLLSGY'
receptor.coords[0:5]
array([[-12.982, -17.271, -11.271],
       [-14.36 , -17.069, -11.749],
       [-15.261, -16.373, -10.703],
       [-15.461, -15.161, -10.801],
       [-14.842, -18.494, -12.077]], dtype=float32)
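The sequence string is a natural starting point for featurization. As one illustrative option (not something pinder prescribes), it can be one-hot encoded with NumPy. This sketch assumes the 20 standard one-letter amino-acid codes plus a single hypothetical "unknown" bucket for anything else; the helper name one_hot_sequence is made up for this example.

```python
import numpy as np

# The 20 standard amino acids in one-letter code; index 20 is reserved
# for anything else (e.g. "X" or non-standard residues).
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_INDEX = {aa: i for i, aa in enumerate(AMINO_ACIDS)}
UNKNOWN_INDEX = len(AMINO_ACIDS)


def one_hot_sequence(sequence: str) -> np.ndarray:
    """Encode a one-letter sequence as an (L, 21) float32 one-hot matrix."""
    indices = np.array(
        [AA_TO_INDEX.get(aa, UNKNOWN_INDEX) for aa in sequence.upper()],
        dtype=np.int64,
    )
    encoding = np.zeros((len(sequence), UNKNOWN_INDEX + 1), dtype=np.float32)
    # Set exactly one column per residue (row) to 1.
    encoding[np.arange(len(sequence)), indices] = 1.0
    return encoding


features = one_hot_sequence("PLGSMFERFT")  # first residues of receptor.sequence
print(features.shape)  # (10, 21)
```

In practice you would call one_hot_sequence(receptor.sequence) and, if your model is in PyTorch, convert the result with torch.from_numpy.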

We can always access the underlying biotite AtomArray via the Structure.atom_array property:

receptor.atom_array[0:5]
array([
	Atom(np.array([-12.982, -17.271, -11.271], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="N", element="N", b_factor=0.0),
	Atom(np.array([-14.36 , -17.069, -11.749], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="CA", element="C", b_factor=0.0),
	Atom(np.array([-15.261, -16.373, -10.703], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="C", element="C", b_factor=0.0),
	Atom(np.array([-15.461, -15.161, -10.801], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="O", element="O", b_factor=0.0),
	Atom(np.array([-14.842, -18.494, -12.077], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="CB", element="C", b_factor=0.0)
])
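In biotite, per-atom annotations such as atom_name are NumPy arrays and an AtomArray can be indexed with a boolean mask, so atom_array[atom_array.atom_name == "CA"] selects the alpha-carbons. So that it runs without pinder or biotite, the sketch below reproduces that mask logic with plain NumPy arrays standing in for atom_array.coord and atom_array.atom_name, copied from the five atoms printed above. It also centers the coordinates on their centroid as an example of a common normalization step.

```python
import numpy as np

# Stand-ins for atom_array.coord and atom_array.atom_name, copied from the
# first five receptor atoms printed above.
coords = np.array(
    [
        [-12.982, -17.271, -11.271],
        [-14.36, -17.069, -11.749],
        [-15.261, -16.373, -10.703],
        [-15.461, -15.161, -10.801],
        [-14.842, -18.494, -12.077],
    ],
    dtype=np.float32,
)
atom_names = np.array(["N", "CA", "C", "O", "CB"])

# Boolean mask selecting alpha-carbons; with biotite the same mask can index
# the AtomArray directly: atom_array[atom_array.atom_name == "CA"].
ca_mask = atom_names == "CA"
ca_coords = coords[ca_mask]
print(ca_coords.shape)  # (1, 3)

# Translate all atoms so that their centroid sits at the origin.
centered = coords - coords.mean(axis=0)
```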

For a more comprehensive overview of all of the Structure class properties, refer to the pinder system tutorial.

Using the PinderLoader to load, filter and transform systems

While the PinderSystem object provides self-contained access to the structures associated with a dimer system, the PinderLoader provides a base abstraction for iterating over systems, applying optional filters and/or transforms, and returning the systems as an iterator. This construct is covered in a separate tutorial.

Using the PinderLoader is not necessary to load systems in your own framework. It is simply one of the provided mechanisms if you find it useful.

PinderLoader brings together filters, transforms and writers to create a generic PinderSystem iterator. It takes either a split name or a list of system IDs as input, and it can sample alternative monomers to form dimer complexes that serve as, e.g., features.
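To make the filter-plus-transform structure concrete, here is a hypothetical, minimal loader written in plain Python. It is NOT pinder's implementation; it only mirrors the idea of an indexable iterator that drops items failing any filter and applies transforms in order. For simplicity it applies filters eagerly at construction time; pinder's own loader may apply them differently (for example, when an item is accessed).

```python
from typing import Callable, Generic, Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


class ToyLoader(Generic[T]):
    """Hypothetical minimal loader: filter items, then transform on access.

    NOT pinder's implementation; it only mirrors the filters + transforms +
    indexable-iterator structure described above.
    """

    def __init__(
        self,
        items: Iterable[T],
        filters: Sequence[Callable[[T], bool]] = (),
        transforms: Sequence[Callable[[T], T]] = (),
    ) -> None:
        # For simplicity, filters are applied eagerly: keep items passing all.
        self.items: List[T] = [x for x in items if all(f(x) for f in filters)]
        self.transforms = list(transforms)

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> T:
        item = self.items[idx]
        # Transforms run in order each time an item is requested.
        for transform in self.transforms:
            item = transform(item)
        return item

    def __iter__(self) -> Iterator[T]:
        for idx in range(len(self)):
            yield self[idx]


# Usage: keep even numbers, then scale them by 10.
toy = ToyLoader(range(10), filters=[lambda x: x % 2 == 0], transforms=[lambda x: x * 10])
print(len(toy), list(toy))  # 5 [0, 20, 40, 60, 80]
```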

Loading a specific split

Note: only the test split has subsets defined (pinder_s, pinder_xl, pinder_af2).

For train and val, you could just do:

from pinder.core import PinderLoader

train_loader = PinderLoader(split="train")
val_loader = PinderLoader(split="val")
import torch
from pinder.core import PinderLoader
from pinder.core.loader import filters

base_filters = [
    filters.FilterByMissingHolo(),
    filters.FilterSubByContacts(min_contacts=5, radius=10.0, calpha_only=True),
    filters.FilterDetachedHolo(radius=12, max_components=2),
]
sub_filters = [
    filters.FilterSubByAtomTypes(min_atom_types=4),
    filters.FilterByHoloOverlap(min_overlap=5),
    filters.FilterByHoloSeqIdentity(min_sequence_identity=0.8),
    filters.FilterSubRmsds(rmsd_cutoff=7.5),
    filters.FilterDetachedSub(radius=12, max_components=2),
]

loader = PinderLoader(
    split="test",
    subset="pinder_af2",
    monomer_priority="holo",
    base_filters=base_filters,
    sub_filters=sub_filters,
)

loader
PinderLoader(split=test, monomers=holo, systems=180)
len(loader)
180
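The parameter names of filters.FilterSubByContacts(min_contacts=5, radius=10.0, calpha_only=True) suggest a minimum number of inter-chain contacts within a distance radius, optionally counting alpha-carbons only. The sketch below shows one way such a count could be computed with NumPy broadcasting. It assumes a "contact" is a receptor-ligand atom pair closer than radius Å; pinder's exact definition may differ (for example, counting residues rather than pairs), and both function names are hypothetical.

```python
import numpy as np


def count_contacts(receptor_xyz: np.ndarray, ligand_xyz: np.ndarray,
                   radius: float = 10.0) -> int:
    """Count receptor-ligand atom pairs closer than `radius` (Angstrom).

    Assumes inputs are (N, 3) and (M, 3) coordinate arrays, e.g. CA-only.
    """
    # Pairwise differences via broadcasting: shape (N, M, 3).
    diff = receptor_xyz[:, None, :] - ligand_xyz[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)  # shape (N, M)
    return int((dist < radius).sum())


def passes_contact_filter(receptor_xyz: np.ndarray, ligand_xyz: np.ndarray,
                          min_contacts: int = 5, radius: float = 10.0) -> bool:
    # Hypothetical predicate mirroring the parameter names above.
    return count_contacts(receptor_xyz, ligand_xyz, radius) >= min_contacts


# Toy CA coordinates: three receptor atoms on the x-axis, two ligand atoms.
rec = np.array([[0.0, 0.0, 0.0], [3.8, 0.0, 0.0], [7.6, 0.0, 0.0]])
lig = np.array([[0.0, 5.0, 0.0], [20.0, 0.0, 0.0]])
print(count_contacts(rec, lig))  # 3
```

The dense (N, M) distance matrix is simple but grows quadratically; for full-atom structures a spatial index (for example, a k-d tree) is usually the more memory-friendly choice.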
data = loader[0]
print(f"Data is a {type(data)}")
system, feature_complex, target_complex = data
type(system), type(feature_complex), type(target_complex)
2024-11-15 12:09:58,528 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
Data is a <class 'tuple'>
(pinder.core.index.system.PinderSystem,
 pinder.core.loader.structure.Structure,
 pinder.core.loader.structure.Structure)
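Each item yields Structure objects whose proteins differ in length, so batching them for a model typically means padding to a common length and keeping a mask of real positions. The sketch below assumes you have already extracted per-system (N_i, 3) coordinate arrays (for example, from Structure.coords); pad_and_stack is a hypothetical helper, and the output can be converted with torch.from_numpy if you train in PyTorch.

```python
from typing import List, Tuple

import numpy as np


def pad_and_stack(arrays: List[np.ndarray], pad_value: float = 0.0
                  ) -> Tuple[np.ndarray, np.ndarray]:
    """Stack variable-length (N_i, D) arrays into (B, N_max, D) plus a mask.

    The boolean mask has shape (B, N_max) and is True at real (non-padded) rows.
    """
    if not arrays:
        raise ValueError("need at least one array to batch")
    n_max = max(a.shape[0] for a in arrays)
    dim = arrays[0].shape[1]
    batch = np.full((len(arrays), n_max, dim), pad_value, dtype=np.float32)
    mask = np.zeros((len(arrays), n_max), dtype=bool)
    for i, a in enumerate(arrays):
        # Copy the real rows; the remainder keeps pad_value.
        batch[i, : a.shape[0]] = a
        mask[i, : a.shape[0]] = True
    return batch, mask


coords_a = np.ones((4, 3), dtype=np.float32)  # e.g. a 4-atom structure
coords_b = np.ones((2, 3), dtype=np.float32)  # e.g. a 2-atom structure
batch, mask = pad_and_stack([coords_a, coords_b])
print(batch.shape, mask.sum(axis=1))  # (2, 4, 3) [4 2]
```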
# You can also use it as an iterator
from tqdm import tqdm
loaded_ids = []
for (system, feature_complex, target_complex) in tqdm(loader):
    loaded_ids.append(system.entry.id)
 86%|████████▌ | 154/180 [03:06<03:50,  8.85s/it]
2024-11-15 12:13:06,029 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 86%|████████▌ | 155/180 [03:07<02:46,  6.66s/it]
2024-11-15 12:13:07,581 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 87%|████████▋ | 156/180 [03:08<01:58,  4.95s/it]
2024-11-15 12:13:08,538 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 87%|████████▋ | 157/180 [03:09<01:26,  3.75s/it]
2024-11-15 12:13:09,493 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
 88%|████████▊ | 158/180 [03:10<01:01,  2.82s/it]
2024-11-15 12:13:10,125 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
 88%|████████▊ | 159/180 [03:10<00:45,  2.16s/it]
2024-11-15 12:13:10,764 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 89%|████████▉ | 160/180 [03:12<00:39,  1.98s/it]
2024-11-15 12:13:12,335 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 89%|████████▉ | 161/180 [03:13<00:33,  1.76s/it]
2024-11-15 12:13:13,573 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 90%|█████████ | 162/180 [03:14<00:26,  1.45s/it]
2024-11-15 12:13:14,306 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 91%|█████████ | 163/180 [03:15<00:23,  1.35s/it]
2024-11-15 12:13:15,427 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 91%|█████████ | 164/180 [03:16<00:19,  1.19s/it]
 92%|█████████▏| 165/180 [03:16<00:13,  1.08it/s]
2024-11-15 12:13:16,557 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 92%|█████████▏| 166/180 [03:17<00:12,  1.15it/s]
2024-11-15 12:13:17,288 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 93%|█████████▎| 167/180 [03:18<00:10,  1.19it/s]
2024-11-15 12:13:18,067 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 93%|█████████▎| 168/180 [03:19<00:10,  1.15it/s]
2024-11-15 12:13:18,998 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 94%|█████████▍| 169/180 [03:20<00:09,  1.12it/s]
2024-11-15 12:13:19,948 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
 94%|█████████▍| 170/180 [03:20<00:07,  1.26it/s]
2024-11-15 12:13:20,518 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
 95%|█████████▌| 171/180 [03:21<00:07,  1.23it/s]
 96%|█████████▌| 172/180 [03:21<00:05,  1.50it/s]
2024-11-15 12:13:21,690 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 96%|█████████▌| 173/180 [03:22<00:05,  1.33it/s]
2024-11-15 12:13:22,638 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 97%|█████████▋| 174/180 [03:24<00:05,  1.02it/s]
2024-11-15 12:13:24,158 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
 97%|█████████▋| 175/180 [03:25<00:04,  1.07it/s]
2024-11-15 12:13:24,981 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
 98%|█████████▊| 176/180 [03:26<00:03,  1.01it/s]
2024-11-15 12:13:26,092 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 98%|█████████▊| 177/180 [03:27<00:03,  1.05s/it]
2024-11-15 12:13:27,285 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 99%|█████████▉| 178/180 [03:28<00:02,  1.10s/it]
2024-11-15 12:13:28,511 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
 99%|█████████▉| 179/180 [03:29<00:01,  1.00s/it]
2024-11-15 12:13:29,282 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
100%|██████████| 180/180 [03:30<00:00,  1.06it/s]
100%|██████████| 180/180 [03:30<00:00,  1.17s/it]

Loading a specific list of systems#

systems = [
    "1df0__A1_Q07009--1df0__B1_Q64537",
    "117e__A1_P00817--117e__B1_P00817",
]
loader = PinderLoader(
    ids=systems,
    monomer_priority="holo",
    base_filters=base_filters,
    sub_filters=sub_filters,
)
passing_ids = []
for item in loader:
    passing_ids.append(item[0].entry.id)

systems_removed_by_filters = set(systems) - set(passing_ids)
systems_removed_by_filters
set()
len(systems) == len(passing_ids)
True

Optional Pinder writer#

If no writer is defined for the PinderLoader, each loaded system is returned as a tuple of (PinderSystem, Structure, Structure) objects: the original PinderSystem, the sampled feature complex, and the target complex.

If you want to explicitly write the (potentially transformed) structure objects to a custom location or in a custom format (e.g. PDB or pickle), you can implement a subclass of PinderWriterBase.

The default writer implements writing to PDB files (leveraging the Structure.to_pdb method on the structure objects).

from pinder.core.loader.writer import PinderDefaultWriter

from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp_dir:
    temp_dir = Path(tmp_dir)
    loader = PinderLoader(
        ids=systems,
        monomer_priority="pred",
        writer=PinderDefaultWriter(temp_dir)
    )
    assert set(loader.index.id) == set(systems)
    for i, r in loader.index.iterrows():
        loaded = loader[i]
        pinder_id = r.id
        system_dir = loader.writer.output_path / pinder_id
        assert system_dir.is_dir()
        print(list(system_dir.glob("af_*.pdb")))
[PosixPath('/tmp/tmphvp65592/117e__A1_P00817--117e__B1_P00817/af__P00817.pdb')]
[PosixPath('/tmp/tmphvp65592/1df0__A1_Q07009--1df0__B1_Q64537/af__Q07009.pdb'), PosixPath('/tmp/tmphvp65592/1df0__A1_Q07009--1df0__B1_Q64537/af__Q64537.pdb')]
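PinderDefaultWriter only handles PDB output. For other formats, the text above suggests subclassing PinderWriterBase. The sketch below shows that pattern without importing pinder. WriterBase is a hypothetical local stand-in: it stores an output_path, as PinderDefaultWriter does above, but the write(sample) hook is our assumption and has not been confirmed as the real method name. PickleWriter is an illustrative subclass that pickles the feature and target complexes into one folder per system.

```python
from __future__ import annotations

import pickle
from pathlib import Path
from tempfile import mkdtemp
from typing import Any


class WriterBase:
    """Hypothetical stand-in for a loader writer base class (not pinder's code)."""

    def __init__(self, output_path: str | Path) -> None:
        self.output_path = Path(output_path)

    def write(self, sample: tuple[Any, Any, Any]) -> None:
        raise NotImplementedError


class PickleWriter(WriterBase):
    """Pickle the feature and target complexes of one sample into a per-system folder."""

    def write(self, sample: tuple[Any, Any, Any]) -> None:
        # With pinder, sample[0] would be a PinderSystem and the folder name would
        # come from sample[0].entry.id; a plain ID string keeps this sketch standalone.
        system_id, feature_complex, target_complex = sample
        system_dir = self.output_path / system_id
        system_dir.mkdir(parents=True, exist_ok=True)
        for name, obj in (("feature", feature_complex), ("target", target_complex)):
            with open(system_dir / f"{name}.pkl", "wb") as fh:
                pickle.dump(obj, fh)


# Usage: write a toy sample into a fresh temporary directory and read one file back.
demo_writer = PickleWriter(mkdtemp())
demo_writer.write(("toy_system", {"coords": [1.0]}, {"coords": [2.0]}))
with open(demo_writer.output_path / "toy_system" / "target.pkl", "rb") as fh:
    demo_target = pickle.load(fh)
```

Before adapting this sketch, check the pinder.core.loader.writer module for the actual base-class method name and signature.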

Constructing torch datasets and dataloaders from pinder systems#

The remaining sections of this tutorial are aimed at readers specifically interested in torch datasets and dataloaders.

Specifically, we will show how to:

  • Implement a PyTorch Dataset to interface with pinder data

  • Include apo and predicted monomers in the data pipeline, with an option to target specific monomer types or randomly sample from the available types

  • Leverage PinderSystem and its associated methods to crop apo/predicted monomers to match the ground-truth holo monomers

  • Write filters and transforms that operate on Structure objects

  • Integrate annotations in data filtering and featurization

  • Create example features to use for training (you will of course choose your own features)

  • Incorporate diversity sampling in the data loader

The pinder.core.loader.dataset module provides two example implementations of how to integrate the pinder dataset into a torch-based machine learning pipeline.

  1. PinderDataset: A map-style torch.utils.data.Dataset that can be used with torch DataLoaders.

  2. PPIDataset: A torch_geometric.data.Dataset designed for use with the torch_geometric package and its DataLoaders.

Together, the two datasets provide an example implementation of how to abstract away the complexity of loading and processing multiple structures associated with each PinderSystem by leveraging the following utilities from pinder:

  • pinder.core.PinderLoader

  • pinder.core.loader.filters

  • pinder.core.loader.transforms

The examples cover two different batch data item structures to illustrate two different use cases:

  • PinderDataset: A batch of (target_complex, feature_complex) pairs, where target_complex and feature_complex are dictionaries of torch.Tensor objects (e.g. atomic coordinates and atom types) representing the holo and sampled (decoy, holo/apo/pred) complexes, respectively.

  • PPIDataset: A batch of PairedPDB objects, where the receptor and ligand are encoded separately in a heterogeneous graph, via torch_geometric.data.HeteroData, holding multiple node and/or edge types in disjunct storage objects.
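In the PinderDataset layout, receptor and ligand share one set of per-atom arrays and are distinguished by chain_ids. The PPIDataset stores each partner separately. The torch-free sketch below shows how to go from the first layout to per-partner groups. split_by_chain is a hypothetical helper, and plain lists stand in for tensors. Treating chain index 0 as the receptor and 1 as the ligand is an assumption.

```python
from __future__ import annotations


def split_by_chain(
    features: dict[str, list], chain_key: str = "chain_ids"
) -> dict[int, dict[str, list]]:
    """Group per-atom feature lists by their chain index.

    Every list in `features` must have one entry per atom. Which index is the
    receptor and which is the ligand (e.g. 0 and 1) is assumed, not guaranteed.
    """
    out: dict[int, dict[str, list]] = {}
    for i, chain in enumerate(features[chain_key]):
        # Create an empty bucket for a chain the first time it is seen.
        part = out.setdefault(chain, {key: [] for key in features})
        for key, values in features.items():
            part[key].append(values[i])
    return out


# Usage: two receptor atoms (chain 0) and one ligand atom (chain 1).
flat = {"chain_ids": [0, 0, 1], "atom_types": [0, 1, 1]}
parts = split_by_chain(flat)
```

Per-residue arrays such as residue_coordinates have a different length from the per-atom arrays, so they would need their own chain index before they could be split this way.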

The remaining sections will be split into:

  1. Using the PinderDataset torch dataset

  2. Using the PPIDataset torch-geometric dataset

  3. How you could implement your own dataset & dataloader

PinderDataset (torch Dataset)#

The PinderDataset is an example implementation of a torch.utils.data.Dataset that represents its data items as a dict containing the following key-value pairs:

  • target_complex: The ground-truth holo dimer, represented with a set of default properties encoded as Tensors

  • feature_complex: The sampled dimer complex, representing “features”, also represented with a set of default properties encoded as Tensors

  • id: The pinder ID for the selected system

  • target_id: The IDs of the receptor and ligand holo monomers, concatenated into a single ID string

  • sample_id: The IDs of the sampled receptor and ligand monomers (holo, apo or predicted), concatenated into a single ID string. This can be useful for debugging or, more generally, for tracking which specific monomers are selected when targeting alternative monomers (more on this shortly)

Each of the target_complex and feature_complex values is a dictionary of structural properties, encoded by default with the pinder.core.loader.geodata.structure2tensor function:

  • atom_coordinates

  • atom_types

  • element_types

  • chain_ids

  • residue_coordinates

  • residue_types

  • residue_ids
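In the data item printed later in this tutorial, each row of residue_coordinates appears to be the position of that residue's CA atom. For example, the first residue coordinate equals the second atom coordinate, which is the CA in N, CA, C, O order. The plain-Python sketch below mimics that encoding for a toy list of atoms. The Atom record, the backbone vocabulary and atoms_to_features are all hypothetical. The real structure2tensor works on pinder Structure objects, returns tensors, and may use different index mappings.

```python
from __future__ import annotations

from dataclasses import dataclass

# Hypothetical backbone vocabulary; unknown atom names share one extra index.
ATOM_VOCAB = {"N": 0, "CA": 1, "C": 2, "O": 3}


@dataclass
class Atom:
    name: str
    res_id: int
    chain: str
    coord: tuple[float, float, float]


def atoms_to_features(atoms: list[Atom]) -> dict[str, list]:
    """Encode atoms as per-atom lists plus one CA coordinate per residue."""
    chain_index: dict[str, int] = {}
    features: dict[str, list] = {
        "atom_types": [],
        "atom_coordinates": [],
        "residue_ids": [],
        "chain_ids": [],
        "residue_coordinates": [],
    }
    for atom in atoms:
        # Chains are numbered in order of first appearance (0, 1, ...).
        chain_idx = chain_index.setdefault(atom.chain, len(chain_index))
        features["atom_types"].append(ATOM_VOCAB.get(atom.name, len(ATOM_VOCAB)))
        features["atom_coordinates"].append(list(atom.coord))
        features["residue_ids"].append(atom.res_id)
        features["chain_ids"].append(chain_idx)
        if atom.name == "CA":
            # Residue-level coordinate = position of that residue's CA atom.
            features["residue_coordinates"].append(list(atom.coord))
    return features


toy_atoms = [
    Atom("N", 4, "R", (0.0, 0.0, 0.0)),
    Atom("CA", 4, "R", (1.0, 0.0, 0.0)),
    Atom("CA", 7, "L", (5.0, 0.0, 0.0)),
]
toy_features = atoms_to_features(toy_atoms)
```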

Note: You can choose to use a different representation by overriding the default values of transform and target_transform. The default transform is the structure2tensor_transform defined in pinder.core.loader.dataset. It simply takes a Structure object and returns a dictionary with string keys and tensor values.

The PinderDataset uses the PinderLoader to apply optional filters and/or transforms and to provide an interface for sampling alternative monomers, and it exposes the transform and target_transform arguments used by the torch Dataset API.

For more details on the torch Dataset APIs, please refer to the tutorials.
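Stripped of filtering, monomer sampling and metadata, a map-style dataset of this kind reduces to __len__ plus __getitem__. Following the usual torch convention, transform is applied to the input (here the feature complex) and target_transform to the target. The torch-free sketch below illustrates only that pattern. ToyPairDataset and its (id, feature, target) records are hypothetical and are not how PinderDataset is implemented.

```python
from __future__ import annotations

from typing import Any, Callable, Iterable, Optional


class ToyPairDataset:
    """Hypothetical map-style dataset over (id, feature, target) records."""

    def __init__(
        self,
        records: Iterable[tuple[str, Any, Any]],
        transform: Optional[Callable[[Any], Any]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self.records = list(records)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        system_id, feature, target = self.records[idx]
        # transform applies to the feature complex, target_transform to the target.
        if self.transform is not None:
            feature = self.transform(feature)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return {"id": system_id, "feature_complex": feature, "target_complex": target}


# Usage: toy "complexes" are lists; len and sum play the role of transforms.
toy_ds = ToyPairDataset(
    [("a", [1, 2], [3]), ("b", [4], [5, 6])],
    transform=len,
    target_transform=sum,
)
```

An object that implements __len__ and __getitem__ satisfies torch's map-style dataset protocol, so something of this shape can be passed to torch.utils.data.DataLoader.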

from pinder.core.loader import filters, transforms
from pinder.core.loader.dataset import PinderDataset, structure2tensor_transform

base_filters = [
    filters.FilterByMissingHolo(),
    filters.FilterSubByContacts(min_contacts=5, radius=10.0, calpha_only=True),
    filters.FilterDetachedHolo(radius=12, max_components=2),
]
sub_filters = [
    filters.FilterSubByAtomTypes(min_atom_types=4),
    filters.FilterByHoloOverlap(min_overlap=5),
    filters.FilterByHoloSeqIdentity(min_sequence_identity=0.8),
    filters.FilterSubRmsds(rmsd_cutoff=7.5),
    filters.FilterDetachedSub(radius=12, max_components=2),
]
# We can include Structure-level transforms (and filters) which will operate on the target and/or feature complexes
structure_transforms = [
    transforms.SelectAtomTypes(atom_types=["CA", "N", "C", "O"])
]
train_dataset = PinderDataset(
    split="train", 
    # We can leverage holo, apo, pred, random and random_mixed monomer sampling strategies
    monomer_priority="random_mixed",
    base_filters=base_filters,
    sub_filters=sub_filters,
    # Apply to the target (ground-truth) complex
    structure_transforms_target=structure_transforms,
    # Apply to the feature complex
    structure_transforms_feature=structure_transforms,
    # This is the default transform if not specified
    transform=structure2tensor_transform,
    target_transform=structure2tensor_transform,
)
assert len(train_dataset) == len(get_index().query('split == "train"'))

train_dataset
<pinder.core.loader.dataset.PinderDataset at 0x7f697dd23ac0>
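The base filters above include FilterSubByContacts(min_contacts=5, radius=10.0, calpha_only=True), which keeps only systems with enough interface contacts. The sketch below shows one plausible way to compute such a check from CA coordinates: it counts receptor-ligand CA pairs closer than radius. This contact definition (pair counting and a strict "<" cutoff) is our assumption and not pinder's code. count_ca_contacts and passes_contact_filter are hypothetical names.

```python
import math
from typing import Sequence

Coord = Sequence[float]


def count_ca_contacts(
    receptor_ca: Sequence[Coord], ligand_ca: Sequence[Coord], radius: float = 10.0
) -> int:
    """Count receptor-ligand CA pairs whose distance is below `radius` (Angstrom)."""
    return sum(
        1
        for rec in receptor_ca
        for lig in ligand_ca
        if math.dist(rec, lig) < radius
    )


def passes_contact_filter(
    receptor_ca: Sequence[Coord],
    ligand_ca: Sequence[Coord],
    min_contacts: int = 5,
    radius: float = 10.0,
) -> bool:
    """Keep a system only if it has at least `min_contacts` interface CA pairs."""
    return count_ca_contacts(receptor_ca, ligand_ca, radius) >= min_contacts
```

The all-pairs loop is quadratic in chain length. For full-size chains, a vectorized distance matrix or a spatial index would be the practical choice.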

Sampling alternative monomers#

The monomer_priority argument can be used to target different mixes of bound and unbound monomers when creating the decoy/feature complex.

The allowed values for monomer_priority are “apo”, “holo”, “pred”, “random” or “random_mixed”.

When monomer_priority is set to one of the available monomer types (holo, apo, pred), the same monomer type will be selected for both receptor and ligand.

When the monomer priority is “random”, a random monomer type will be selected from the set of monomer types available for both the receptor and ligand. This option ensures the same type of monomer is used for the receptor and ligand.

When the monomer priority is “random_mixed”, a random monomer type will be selected for each of receptor and ligand, separately.

The fallback_to_holo option (enabled by default) silently falls back to holo monomers when monomer_priority is set to apo or pred but the corresponding monomer is not available for the dimer.

This is useful when only one of receptor or ligand has an unbound monomer, but you wish to include apo or predicted structures in your workflow.

If fallback_to_holo is disabled, an error will be raised when the monomer_priority is set to one of apo or pred, but the corresponding monomer is not available for the dimer.

By default, when apo monomers are selected, the “canonical” apo monomer is used. Although a single canonical apo monomer should be used for evaluation, pinder provides multiple apo monomers paired to a single holo monomer (when available). To include these non-canonical/alternative monomers, specify use_canonical_apo=False when constructing the PinderLoader or PinderDataset objects.
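To make the selection rules above concrete, here is a small pure-Python model of how the (receptor, ligand) monomer types could be chosen for each monomer_priority value. It reflects our reading of the text and is not pinder's implementation. In particular, we assume the holo fallback applies per side, so a receptor with an apo monomer can be paired with a holo ligand. The choice among multiple apo monomers (canonical or not) is not modeled. select_monomer_types and the available mapping are hypothetical.

```python
from __future__ import annotations

import random
from typing import Optional

PRIORITIES = {"holo", "apo", "pred", "random", "random_mixed"}


def select_monomer_types(
    available: dict[str, set[str]],
    priority: str,
    fallback_to_holo: bool = True,
    rng: Optional[random.Random] = None,
) -> tuple[str, str]:
    """Return (receptor_type, ligand_type) for one system.

    `available` maps "receptor" and "ligand" to the monomer types present for
    that side; holo is assumed to always be present.
    """
    if priority not in PRIORITIES:
        raise ValueError(f"unknown monomer_priority: {priority!r}")
    if rng is None:
        rng = random.Random()
    rec, lig = available["receptor"], available["ligand"]

    if priority == "random":
        # One type for both sides, drawn from the types available to both.
        shared = rng.choice(sorted(rec & lig))
        return shared, shared
    if priority == "random_mixed":
        # Independent draws for receptor and ligand.
        return rng.choice(sorted(rec)), rng.choice(sorted(lig))

    # Fixed type (holo/apo/pred); a side lacking it falls back to holo if allowed.
    chosen = []
    for side in (rec, lig):
        if priority in side:
            chosen.append(priority)
        elif fallback_to_holo:
            chosen.append("holo")
        else:
            raise ValueError(f"no {priority} monomer available and fallback_to_holo=False")
    return chosen[0], chosen[1]
```

For example, with an apo receptor and a holo-only ligand, select_monomer_types({"receptor": {"holo", "apo"}, "ligand": {"holo"}}, "apo") pairs the apo receptor with the holo ligand. Passing fallback_to_holo=False raises an error instead.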

data_item = train_dataset[0]
data_item
{'target_complex': {'atom_types': tensor([[0.],
          [1.],
          [2.],
          ...,
          [1.],
          [2.],
          [3.]]),
  'element_types': tensor([[3.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [2.]]),
  'residue_types': tensor([[16.],
          [16.],
          [16.],
          ...,
          [ 0.],
          [ 0.],
          [ 0.]]),
  'atom_coordinates': tensor([[131.7500, 429.3090, 163.5360],
          [132.6810, 428.2520, 163.1550],
          [133.5150, 428.6750, 161.9500],
          ...,
          [177.7620, 463.8650, 166.9020],
          [177.4130, 465.0800, 167.7550],
          [176.8000, 464.9490, 168.8150]]),
  'residue_coordinates': tensor([[132.6810, 428.2520, 163.1550],
          [133.5560, 429.2490, 159.5910],
          [133.8750, 432.8980, 160.6290],
          [136.1110, 431.8050, 163.5130],
          [138.3420, 429.8920, 161.0830],
          [138.5230, 432.9090, 158.7600],
          [139.4710, 435.1730, 161.6750],
          [142.1200, 432.6310, 162.6740],
          [143.5890, 432.7500, 159.1600],
          [143.6190, 436.5600, 159.1540],
          [145.2600, 436.8040, 162.5830],
          [147.8380, 434.1320, 161.7330],
          [148.7390, 435.8920, 158.4790],
          [149.0640, 439.2630, 160.2240],
          [151.2220, 437.7670, 162.9840],
          [153.4710, 435.8030, 160.6120],
          [153.9420, 438.9180, 158.4710],
          [156.1140, 440.2660, 161.3150],
          [158.3330, 437.1860, 161.4250],
          [161.8090, 436.4120, 160.1100],
          [161.1120, 432.8370, 158.9350],
          [157.4020, 432.0710, 159.3930],
          [156.5290, 428.3870, 159.0310],
          [154.0400, 428.1160, 156.1500],
          [153.3520, 424.6590, 154.7510],
          [152.8370, 424.8210, 150.9870],
          [150.9700, 422.4160, 148.7320],
          [153.5900, 422.6860, 145.9670],
          [157.3050, 423.3410, 145.6600],
          [158.4860, 426.6730, 144.2700],
          [160.7730, 427.0880, 141.2590],
          [163.9630, 429.0530, 141.9600],
          [166.4230, 427.4330, 139.5550],
          [167.6170, 430.8330, 138.3250],
          [170.5180, 431.9210, 140.5320],
          [169.7800, 435.6680, 140.3160],
          [167.6790, 437.8790, 142.6110],
          [167.3350, 435.1910, 145.2890],
          [168.9090, 437.3000, 148.0580],
          [166.1860, 439.9440, 148.5750],
          [164.1010, 438.1300, 151.1890],
          [163.0190, 438.8690, 154.7450],
          [164.6750, 435.6560, 155.9860],
          [168.1170, 436.6550, 154.6330],
          [169.8040, 437.0930, 158.0010],
          [173.2680, 438.4850, 158.5750],
          [174.9270, 440.8490, 156.1460],
          [175.9170, 440.8010, 152.5000],
          [179.6060, 441.0910, 153.4330],
          [180.8000, 439.9830, 156.8710],
          [184.2540, 439.9980, 158.4580],
          [185.5840, 437.2970, 160.7670],
          [186.4140, 439.9000, 163.4480],
          [182.9430, 441.4570, 163.7500],
          [181.4930, 442.0910, 167.2100],
          [178.0780, 440.5360, 167.9060],
          [176.1650, 440.9290, 171.1790],
          [173.0530, 438.9530, 172.1310],
          [170.3990, 440.6990, 174.2170],
          [166.7880, 440.0710, 175.2260],
          [164.4480, 442.1010, 173.0340],
          [160.9780, 443.6110, 173.3810],
          [158.8330, 443.9000, 170.2730],
          [160.1820, 442.9360, 166.8360],
          [163.5010, 444.7230, 166.1630],
          [164.6310, 442.4290, 163.3160],
          [166.6330, 444.2560, 160.6240],
          [166.1700, 447.5210, 162.5220],
          [168.1350, 449.9450, 164.6440],
          [168.0040, 449.0470, 168.3240],
          [168.4210, 451.0540, 171.5220],
          [169.0440, 450.0360, 175.1300],
          [166.1650, 450.3680, 177.6100],
          [166.7800, 449.8300, 181.3340],
          [163.8480, 448.9150, 183.5760],
          [163.3500, 443.9770, 187.9750],
          [165.0240, 442.4930, 184.8870],
          [168.1120, 443.0980, 182.7810],
          [168.1280, 445.9210, 180.2050],
          [166.7230, 445.0130, 176.8030],
          [167.0690, 446.1330, 173.1930],
          [164.0790, 447.8390, 171.5780],
          [163.4630, 449.0630, 168.0350],
          [164.4160, 452.7320, 168.0130],
          [166.7460, 455.3190, 166.5790],
          [167.5460, 459.0090, 166.8440],
          [170.2010, 459.9620, 169.3810],
          [169.5150, 456.8220, 171.4570],
          [170.7440, 454.3880, 168.7860],
          [172.8900, 451.6810, 170.3850],
          [173.2420, 449.3000, 167.4490],
          [171.4700, 447.3670, 164.7150],
          [169.7840, 443.9650, 164.9670],
          [170.8880, 441.4120, 162.3680],
          [169.3320, 438.1640, 163.6510],
          [166.7170, 436.9060, 166.1030],
          [166.2830, 433.5290, 167.7900],
          [162.5490, 433.4260, 168.5040],
          [162.6440, 430.3540, 170.7620],
          [164.9490, 432.1120, 173.2350],
          [163.6530, 435.5760, 172.2210],
          [167.2540, 436.7450, 171.8300],
          [168.3710, 439.3650, 169.3060],
          [171.8910, 439.6680, 167.8930],
          [173.1360, 443.2500, 167.5870],
          [176.0990, 444.7950, 165.8030],
          [177.1670, 447.8660, 167.8270],
          [177.4160, 451.2430, 166.1360],
          [181.1790, 451.0810, 166.7480],
          [181.3560, 448.4350, 163.9980],
          [180.6020, 448.8920, 160.3060],
          [177.8440, 446.8990, 158.6170],
          [176.7540, 446.3160, 155.0150],
          [173.3820, 444.6900, 154.4770],
          [169.6430, 444.9660, 153.9760],
          [167.6080, 446.9450, 156.5140],
          [163.9630, 447.9270, 156.8800],
          [163.1520, 451.2160, 155.1430],
          [160.3870, 453.6570, 156.0420],
          [158.9330, 453.6920, 152.5150],
          [159.8050, 453.2460, 148.8360],
          [161.2560, 456.7540, 148.4440],
          [164.9420, 455.8330, 148.8520],
          [167.3780, 455.9500, 145.9260],
          [171.0380, 455.0610, 145.5060],
          [173.5830, 457.5340, 146.9270],
          [171.2360, 458.9430, 149.5740],
          [171.8150, 459.8640, 153.2130],
          [169.5910, 457.9090, 155.6000],
          [168.4960, 458.3960, 159.2100],
          [167.0870, 455.9020, 161.7110],
          [163.6430, 457.0390, 162.8610],
          [161.8700, 456.3550, 166.1660],
          [160.7490, 452.9260, 164.8950],
          [164.2440, 451.8750, 163.8110],
          [163.4490, 452.3040, 160.1090],
          [165.6820, 454.0140, 157.5590],
          [164.2940, 457.2310, 156.0670],
          [165.6810, 459.7290, 153.5780],
          [167.3870, 462.7790, 155.0760],
          [170.2060, 465.8340, 161.8360],
          [172.1620, 462.7680, 160.6890],
          [174.0280, 461.7040, 163.8180],
          [175.2110, 458.5660, 161.9930],
          [176.5410, 458.2410, 158.4450],
          [174.5550, 455.7530, 156.3540],
          [174.3580, 455.6030, 152.5600],
          [171.8590, 453.6260, 150.4910],
          [173.6300, 451.2730, 148.0810],
          [170.4820, 450.1750, 146.2000],
          [166.8520, 451.1980, 145.7810],
          [164.0730, 450.1880, 148.1490],
          [162.6710, 446.7750, 147.1980],
          [159.1420, 445.8080, 148.2240],
          [158.3160, 442.3640, 149.5950],
          [155.0300, 443.1470, 151.4050],
          [152.6800, 446.0960, 151.8500],
          [155.0610, 447.4890, 154.4910],
          [158.2010, 445.3460, 154.0190],
          [160.8090, 447.5340, 152.2870],
          [164.4130, 446.3150, 152.1360],
          [167.3360, 448.5380, 151.1380],
          [171.0310, 447.6630, 150.9500],
          [172.8160, 450.1440, 153.2260],
          [176.3550, 450.4540, 154.5850],
          [176.6310, 452.0780, 158.0240],
          [180.1600, 453.0970, 159.0310],
          [179.9110, 454.9710, 162.3160],
          [178.7370, 458.4030, 163.4610],
          [179.1350, 461.9150, 162.0830],
          [180.3160, 464.9940, 163.9950],
          [195.1700, 454.0210, 194.1160],
          [196.4220, 450.5570, 193.1610],
          [195.0530, 447.5260, 191.3340],
          [195.7530, 445.4320, 194.4410],
          [195.6760, 446.8960, 197.9710],
          [199.1230, 447.9300, 199.1640],
          [198.4070, 447.0140, 202.7840],
          [198.1130, 443.3080, 203.5090],
          [195.0350, 441.7530, 205.1060],
          [194.7040, 440.1720, 208.5300],
          [191.7420, 437.9460, 207.6670],
          [190.9460, 435.0990, 205.3140],
          [188.0870, 435.6760, 202.8880],
          [184.8300, 433.7320, 203.1830],
          [184.1270, 431.8260, 199.9570],
          [180.5450, 430.8500, 200.7460],
          [178.2470, 433.3320, 199.0220],
          [178.2770, 433.9260, 195.2680],
          [176.3170, 437.2030, 195.0280],
          [177.7690, 440.7330, 194.9660],
          [181.3350, 439.3910, 195.1330],
          [182.5820, 441.3460, 192.0940],
          [182.3530, 444.9560, 193.3610],
          [186.0870, 445.2240, 194.0290],
          [188.7230, 447.5570, 192.6190],
          [190.8760, 444.5870, 191.5490],
          [188.2670, 443.5230, 188.9590],
          [190.1360, 444.3010, 185.7360],
          [188.8140, 444.1990, 182.1980],
          [185.1300, 444.1560, 181.3470],
          [182.1290, 442.2030, 182.6060],
          [181.3330, 441.2700, 178.9860],
          [184.3690, 440.6570, 176.7700],
          [184.4240, 439.7520, 173.0740],
          [187.1540, 437.5560, 171.6020],
          [187.5170, 439.9750, 168.6500],
          [188.4180, 443.0500, 170.7240],
          [191.4360, 445.2100, 169.8910],
          [193.7480, 445.6510, 172.8940],
          [196.8200, 447.8990, 173.0280],
          [199.4910, 448.6120, 175.6310],
          [201.2080, 451.9090, 176.3780],
          [203.6110, 453.4740, 178.8680],
          [201.6130, 454.9210, 181.7570],
          [202.0690, 458.3840, 183.2730],
          [200.4970, 458.3690, 186.7180],
          [197.7960, 455.8130, 187.5240],
          [195.1460, 455.8470, 184.7720],
          [193.6160, 452.5740, 185.9790],
          [189.8600, 452.3800, 185.3150],
          [189.9440, 455.8580, 183.7680],
          [189.7400, 457.4310, 180.3410],
          [193.1630, 457.9480, 178.8100],
          [194.6770, 460.5170, 176.4640],
          [197.9720, 460.6430, 174.6010],
          [200.7710, 462.5710, 176.3210],
          [203.7970, 463.7400, 174.3330],
          [207.1150, 464.5820, 175.9770],
          [210.1390, 466.5260, 174.8060],
          [210.8850, 463.9390, 172.1200],
          [213.6440, 461.8050, 173.6150],
          [211.1280, 459.2430, 174.9140],
          [208.0390, 457.7610, 173.2930],
          [204.4770, 458.9110, 173.9310],
          [202.5240, 457.8000, 177.0030],
          [198.9190, 457.7220, 178.2440],
          [197.5730, 459.9040, 181.0500],
          [194.2310, 460.2350, 182.8060],
          [192.1140, 462.6430, 180.7850],
          [189.0250, 463.1720, 178.7000],
          [187.1240, 465.8360, 176.8220],
          [188.2850, 466.3320, 173.2420],
          [191.5580, 464.4520, 173.8600],
          [189.9950, 461.1280, 174.9010],
          [191.8440, 458.2970, 173.1540],
          [190.4670, 455.2690, 174.9900],
          [189.7380, 453.5820, 178.3070],
          [192.3390, 451.8830, 180.5000],
          [191.2020, 448.5340, 181.9050],
          [194.3800, 447.0280, 183.4050],
          [197.8700, 447.9650, 184.5850],
          [200.9840, 445.8270, 184.9660],
          [202.8290, 447.6220, 187.7590],
          [206.2120, 445.9180, 187.3400],
          [206.6040, 447.1310, 183.7460],
          [204.4430, 450.2380, 184.1830],
          [202.2310, 449.1370, 181.2890],
          [198.6660, 450.3790, 180.7720],
          [196.2310, 448.2490, 178.7680],
          [194.0550, 450.4240, 176.5260],
          [190.8490, 449.7090, 174.6070],
          [190.4360, 452.3080, 171.8240],
          [187.3830, 454.5590, 171.7980],
          [186.4670, 452.9730, 168.4480],
          [185.9050, 449.6480, 170.2580],
          [182.8890, 448.8910, 172.4330],
          [183.5680, 448.6030, 176.1630],
          [181.3330, 447.5400, 179.0590],
          [182.7340, 447.5270, 182.5670],
          [183.3280, 449.2260, 185.8920],
          [184.8760, 452.6980, 185.6990],
          [185.7140, 455.4170, 188.2030],
          [182.9420, 457.9530, 188.8300],
          [183.2930, 461.5650, 189.9350],
          [182.7580, 462.1320, 193.6660],
          [180.0620, 464.7490, 193.1030],
          [176.4120, 463.9030, 192.5790],
          [177.1550, 460.1730, 192.6930],
          [173.8180, 459.5980, 194.4390],
          [172.0770, 461.1030, 191.4040],
          [173.6950, 458.5410, 189.0910],
          [171.0060, 456.0190, 188.1270],
          [170.8210, 453.3390, 185.4210],
          [170.0180, 454.8270, 182.0350],
          [171.8690, 458.1080, 182.6590],
          [174.2320, 459.6520, 180.1210],
          [177.8640, 459.9130, 181.2230],
          [180.8780, 461.8410, 179.9490],
          [184.5650, 461.8070, 180.8430],
          [186.0650, 464.7990, 182.6440],
          [189.5780, 466.2740, 182.6170],
          [190.7120, 463.6340, 185.1390],
          [189.2980, 460.6630, 183.2290],
          [186.3560, 460.3000, 185.6240],
          [182.7840, 459.6590, 184.5210],
          [180.2130, 462.3320, 185.3590],
          [176.5320, 462.8410, 184.6140],
          [175.8580, 465.0300, 181.5790],
          [178.1930, 465.7110, 174.9000],
          [180.9150, 464.1190, 172.7670],
          [181.2510, 460.5420, 174.1130],
          [178.4430, 457.9660, 174.2140],
          [178.5420, 456.4790, 177.7180],
          [175.4730, 455.1590, 179.5500],
          [175.4670, 454.0690, 183.1870],
          [174.1210, 450.5430, 183.6830],
          [173.9850, 450.6340, 187.5030],
          [174.3640, 452.9240, 190.5240],
          [177.5140, 454.2200, 192.1820],
          [179.3060, 451.6570, 194.3540],
          [181.4890, 453.2110, 197.0570],
          [184.7390, 451.3110, 197.6330],
          [186.4870, 454.0230, 199.6690],
          [185.5270, 457.3810, 201.1790],
          [186.7910, 459.0050, 197.9610],
          [186.6010, 456.0430, 195.5320],
          [183.3680, 455.3620, 193.6270],
          [182.9110, 452.9430, 190.7280],
          [180.0190, 452.5230, 188.3060],
          [179.0910, 450.0250, 185.5960],
          [178.9850, 451.7840, 182.2280],
          [179.0510, 451.0080, 178.5100],
          [181.0530, 453.1550, 176.0780],
          [180.1320, 452.7170, 172.4060],
          [182.0570, 455.4530, 170.6120],
          [181.8140, 459.1820, 170.0770],
          [178.5990, 461.0580, 169.3130],
          [177.7620, 463.8650, 166.9020]]),
  'residue_ids': tensor([  4.,   4.,   4.,  ..., 182., 182., 182.]),
  'chain_ids': tensor([0., 0., 0.,  ..., 1., 1., 1.])},
 'feature_complex': {'atom_types': tensor([[0.],
          [1.],
          [2.],
          ...,
          [1.],
          [2.],
          [3.]]),
  'element_types': tensor([[3.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [2.]]),
  'residue_types': tensor([[16.],
          [16.],
          [16.],
          ...,
          [ 0.],
          [ 0.],
          [ 0.]]),
  'atom_coordinates': tensor([[131.7500, 429.3090, 163.5360],
          [132.6810, 428.2520, 163.1550],
          [133.5150, 428.6750, 161.9500],
          ...,
          [177.7620, 463.8650, 166.9020],
          [177.4130, 465.0800, 167.7550],
          [176.8000, 464.9490, 168.8150]]),
  'residue_coordinates': tensor([[132.6810, 428.2520, 163.1550],
          [133.5560, 429.2490, 159.5910],
          [133.8750, 432.8980, 160.6290],
          [136.1110, 431.8050, 163.5130],
          [138.3420, 429.8920, 161.0830],
          [138.5230, 432.9090, 158.7600],
          [139.4710, 435.1730, 161.6750],
          [142.1200, 432.6310, 162.6740],
          [143.5890, 432.7500, 159.1600],
          [143.6190, 436.5600, 159.1540],
          [145.2600, 436.8040, 162.5830],
          [147.8380, 434.1320, 161.7330],
          [148.7390, 435.8920, 158.4790],
          [149.0640, 439.2630, 160.2240],
          [151.2220, 437.7670, 162.9840],
          [153.4710, 435.8030, 160.6120],
          [153.9420, 438.9180, 158.4710],
          [156.1140, 440.2660, 161.3150],
          [158.3330, 437.1860, 161.4250],
          [161.8090, 436.4120, 160.1100],
          [161.1120, 432.8370, 158.9350],
          [157.4020, 432.0710, 159.3930],
          [156.5290, 428.3870, 159.0310],
          [154.0400, 428.1160, 156.1500],
          [153.3520, 424.6590, 154.7510],
          [152.8370, 424.8210, 150.9870],
          [150.9700, 422.4160, 148.7320],
          [153.5900, 422.6860, 145.9670],
          [157.3050, 423.3410, 145.6600],
          [158.4860, 426.6730, 144.2700],
          [160.7730, 427.0880, 141.2590],
          [163.9630, 429.0530, 141.9600],
          [166.4230, 427.4330, 139.5550],
          [167.6170, 430.8330, 138.3250],
          [170.5180, 431.9210, 140.5320],
          [169.7800, 435.6680, 140.3160],
          [167.6790, 437.8790, 142.6110],
          [167.3350, 435.1910, 145.2890],
          [168.9090, 437.3000, 148.0580],
          [166.1860, 439.9440, 148.5750],
          [164.1010, 438.1300, 151.1890],
          [163.0190, 438.8690, 154.7450],
          [164.6750, 435.6560, 155.9860],
          [168.1170, 436.6550, 154.6330],
          [169.8040, 437.0930, 158.0010],
          [173.2680, 438.4850, 158.5750],
          [174.9270, 440.8490, 156.1460],
          [175.9170, 440.8010, 152.5000],
          [179.6060, 441.0910, 153.4330],
          [180.8000, 439.9830, 156.8710],
          [184.2540, 439.9980, 158.4580],
          [185.5840, 437.2970, 160.7670],
          [186.4140, 439.9000, 163.4480],
          [182.9430, 441.4570, 163.7500],
          [181.4930, 442.0910, 167.2100],
          [178.0780, 440.5360, 167.9060],
          [176.1650, 440.9290, 171.1790],
          [173.0530, 438.9530, 172.1310],
          [170.3990, 440.6990, 174.2170],
          [166.7880, 440.0710, 175.2260],
          [164.4480, 442.1010, 173.0340],
          [160.9780, 443.6110, 173.3810],
          [158.8330, 443.9000, 170.2730],
          [160.1820, 442.9360, 166.8360],
          [163.5010, 444.7230, 166.1630],
          [164.6310, 442.4290, 163.3160],
          [166.6330, 444.2560, 160.6240],
          [166.1700, 447.5210, 162.5220],
          [168.1350, 449.9450, 164.6440],
          [168.0040, 449.0470, 168.3240],
          [168.4210, 451.0540, 171.5220],
          [169.0440, 450.0360, 175.1300],
          [166.1650, 450.3680, 177.6100],
          [166.7800, 449.8300, 181.3340],
          [163.8480, 448.9150, 183.5760],
          [163.3500, 443.9770, 187.9750],
          [165.0240, 442.4930, 184.8870],
          [168.1120, 443.0980, 182.7810],
          [168.1280, 445.9210, 180.2050],
          [166.7230, 445.0130, 176.8030],
          [167.0690, 446.1330, 173.1930],
          [164.0790, 447.8390, 171.5780],
          [163.4630, 449.0630, 168.0350],
          [164.4160, 452.7320, 168.0130],
          [166.7460, 455.3190, 166.5790],
          [167.5460, 459.0090, 166.8440],
          [170.2010, 459.9620, 169.3810],
          [169.5150, 456.8220, 171.4570],
          [170.7440, 454.3880, 168.7860],
          [172.8900, 451.6810, 170.3850],
          [173.2420, 449.3000, 167.4490],
          [171.4700, 447.3670, 164.7150],
          [169.7840, 443.9650, 164.9670],
          [170.8880, 441.4120, 162.3680],
          [169.3320, 438.1640, 163.6510],
          [166.7170, 436.9060, 166.1030],
          [166.2830, 433.5290, 167.7900],
          [162.5490, 433.4260, 168.5040],
          [162.6440, 430.3540, 170.7620],
          [164.9490, 432.1120, 173.2350],
          [163.6530, 435.5760, 172.2210],
          [167.2540, 436.7450, 171.8300],
          [168.3710, 439.3650, 169.3060],
          [171.8910, 439.6680, 167.8930],
          [173.1360, 443.2500, 167.5870],
          [176.0990, 444.7950, 165.8030],
          [177.1670, 447.8660, 167.8270],
          [177.4160, 451.2430, 166.1360],
          [181.1790, 451.0810, 166.7480],
          [181.3560, 448.4350, 163.9980],
          [180.6020, 448.8920, 160.3060],
          [177.8440, 446.8990, 158.6170],
          [176.7540, 446.3160, 155.0150],
          [173.3820, 444.6900, 154.4770],
          [169.6430, 444.9660, 153.9760],
          [167.6080, 446.9450, 156.5140],
          [163.9630, 447.9270, 156.8800],
          [163.1520, 451.2160, 155.1430],
          [160.3870, 453.6570, 156.0420],
          [158.9330, 453.6920, 152.5150],
          [159.8050, 453.2460, 148.8360],
          [161.2560, 456.7540, 148.4440],
          [164.9420, 455.8330, 148.8520],
          [167.3780, 455.9500, 145.9260],
          [171.0380, 455.0610, 145.5060],
          [173.5830, 457.5340, 146.9270],
          [171.2360, 458.9430, 149.5740],
          [171.8150, 459.8640, 153.2130],
          [169.5910, 457.9090, 155.6000],
          [168.4960, 458.3960, 159.2100],
          [167.0870, 455.9020, 161.7110],
          [163.6430, 457.0390, 162.8610],
          [161.8700, 456.3550, 166.1660],
          [160.7490, 452.9260, 164.8950],
          [164.2440, 451.8750, 163.8110],
          [163.4490, 452.3040, 160.1090],
          [165.6820, 454.0140, 157.5590],
          [164.2940, 457.2310, 156.0670],
          [165.6810, 459.7290, 153.5780],
          [167.3870, 462.7790, 155.0760],
          [170.2060, 465.8340, 161.8360],
          [172.1620, 462.7680, 160.6890],
          [174.0280, 461.7040, 163.8180],
          [175.2110, 458.5660, 161.9930],
          [176.5410, 458.2410, 158.4450],
          [174.5550, 455.7530, 156.3540],
          [174.3580, 455.6030, 152.5600],
          [171.8590, 453.6260, 150.4910],
          [173.6300, 451.2730, 148.0810],
          [170.4820, 450.1750, 146.2000],
          [166.8520, 451.1980, 145.7810],
          [164.0730, 450.1880, 148.1490],
          [162.6710, 446.7750, 147.1980],
          [159.1420, 445.8080, 148.2240],
          [158.3160, 442.3640, 149.5950],
          [155.0300, 443.1470, 151.4050],
          [152.6800, 446.0960, 151.8500],
          [155.0610, 447.4890, 154.4910],
          [158.2010, 445.3460, 154.0190],
          [160.8090, 447.5340, 152.2870],
          [164.4130, 446.3150, 152.1360],
          [167.3360, 448.5380, 151.1380],
          [171.0310, 447.6630, 150.9500],
          [172.8160, 450.1440, 153.2260],
          [176.3550, 450.4540, 154.5850],
          [176.6310, 452.0780, 158.0240],
          [180.1600, 453.0970, 159.0310],
          [179.9110, 454.9710, 162.3160],
          [178.7370, 458.4030, 163.4610],
          [179.1350, 461.9150, 162.0830],
          [180.3160, 464.9940, 163.9950],
          [195.1700, 454.0210, 194.1160],
          [196.4220, 450.5570, 193.1610],
          [195.0530, 447.5260, 191.3340],
          [195.7530, 445.4320, 194.4410],
          [195.6760, 446.8960, 197.9710],
          [199.1230, 447.9300, 199.1640],
          [198.4070, 447.0140, 202.7840],
          [198.1130, 443.3080, 203.5090],
          [195.0350, 441.7530, 205.1060],
          [194.7040, 440.1720, 208.5300],
          [191.7420, 437.9460, 207.6670],
          [190.9460, 435.0990, 205.3140],
          [188.0870, 435.6760, 202.8880],
          [184.8300, 433.7320, 203.1830],
          [184.1270, 431.8260, 199.9570],
          [180.5450, 430.8500, 200.7460],
          [178.2470, 433.3320, 199.0220],
          [178.2770, 433.9260, 195.2680],
          [176.3170, 437.2030, 195.0280],
          [177.7690, 440.7330, 194.9660],
          [181.3350, 439.3910, 195.1330],
          [182.5820, 441.3460, 192.0940],
          [182.3530, 444.9560, 193.3610],
          [186.0870, 445.2240, 194.0290],
          [188.7230, 447.5570, 192.6190],
          [190.8760, 444.5870, 191.5490],
          [188.2670, 443.5230, 188.9590],
          [190.1360, 444.3010, 185.7360],
          [188.8140, 444.1990, 182.1980],
          [185.1300, 444.1560, 181.3470],
          [182.1290, 442.2030, 182.6060],
          [181.3330, 441.2700, 178.9860],
          [184.3690, 440.6570, 176.7700],
          [184.4240, 439.7520, 173.0740],
          [187.1540, 437.5560, 171.6020],
          [187.5170, 439.9750, 168.6500],
          [188.4180, 443.0500, 170.7240],
          [191.4360, 445.2100, 169.8910],
          [193.7480, 445.6510, 172.8940],
          [196.8200, 447.8990, 173.0280],
          [199.4910, 448.6120, 175.6310],
          [201.2080, 451.9090, 176.3780],
          [203.6110, 453.4740, 178.8680],
          [201.6130, 454.9210, 181.7570],
          [202.0690, 458.3840, 183.2730],
          [200.4970, 458.3690, 186.7180],
          [197.7960, 455.8130, 187.5240],
          [195.1460, 455.8470, 184.7720],
          [193.6160, 452.5740, 185.9790],
          [189.8600, 452.3800, 185.3150],
          [189.9440, 455.8580, 183.7680],
          [189.7400, 457.4310, 180.3410],
          [193.1630, 457.9480, 178.8100],
          [194.6770, 460.5170, 176.4640],
          [197.9720, 460.6430, 174.6010],
          [200.7710, 462.5710, 176.3210],
          [203.7970, 463.7400, 174.3330],
          [207.1150, 464.5820, 175.9770],
          [210.1390, 466.5260, 174.8060],
          [210.8850, 463.9390, 172.1200],
          [213.6440, 461.8050, 173.6150],
          [211.1280, 459.2430, 174.9140],
          [208.0390, 457.7610, 173.2930],
          [204.4770, 458.9110, 173.9310],
          [202.5240, 457.8000, 177.0030],
          [198.9190, 457.7220, 178.2440],
          [197.5730, 459.9040, 181.0500],
          [194.2310, 460.2350, 182.8060],
          [192.1140, 462.6430, 180.7850],
          [189.0250, 463.1720, 178.7000],
          [187.1240, 465.8360, 176.8220],
          [188.2850, 466.3320, 173.2420],
          [191.5580, 464.4520, 173.8600],
          [189.9950, 461.1280, 174.9010],
          [191.8440, 458.2970, 173.1540],
          [190.4670, 455.2690, 174.9900],
          [189.7380, 453.5820, 178.3070],
          [192.3390, 451.8830, 180.5000],
          [191.2020, 448.5340, 181.9050],
          [194.3800, 447.0280, 183.4050],
          [197.8700, 447.9650, 184.5850],
          [200.9840, 445.8270, 184.9660],
          [202.8290, 447.6220, 187.7590],
          [206.2120, 445.9180, 187.3400],
          [206.6040, 447.1310, 183.7460],
          [204.4430, 450.2380, 184.1830],
          [202.2310, 449.1370, 181.2890],
          [198.6660, 450.3790, 180.7720],
          [196.2310, 448.2490, 178.7680],
          [194.0550, 450.4240, 176.5260],
          [190.8490, 449.7090, 174.6070],
          [190.4360, 452.3080, 171.8240],
          [187.3830, 454.5590, 171.7980],
          [186.4670, 452.9730, 168.4480],
          [185.9050, 449.6480, 170.2580],
          [182.8890, 448.8910, 172.4330],
          [183.5680, 448.6030, 176.1630],
          [181.3330, 447.5400, 179.0590],
          [182.7340, 447.5270, 182.5670],
          [183.3280, 449.2260, 185.8920],
          [184.8760, 452.6980, 185.6990],
          [185.7140, 455.4170, 188.2030],
          [182.9420, 457.9530, 188.8300],
          [183.2930, 461.5650, 189.9350],
          [182.7580, 462.1320, 193.6660],
          [180.0620, 464.7490, 193.1030],
          [176.4120, 463.9030, 192.5790],
          [177.1550, 460.1730, 192.6930],
          [173.8180, 459.5980, 194.4390],
          [172.0770, 461.1030, 191.4040],
          [173.6950, 458.5410, 189.0910],
          [171.0060, 456.0190, 188.1270],
          [170.8210, 453.3390, 185.4210],
          [170.0180, 454.8270, 182.0350],
          [171.8690, 458.1080, 182.6590],
          [174.2320, 459.6520, 180.1210],
          [177.8640, 459.9130, 181.2230],
          [180.8780, 461.8410, 179.9490],
          [184.5650, 461.8070, 180.8430],
          [186.0650, 464.7990, 182.6440],
          [189.5780, 466.2740, 182.6170],
          [190.7120, 463.6340, 185.1390],
          [189.2980, 460.6630, 183.2290],
          [186.3560, 460.3000, 185.6240],
          [182.7840, 459.6590, 184.5210],
          [180.2130, 462.3320, 185.3590],
          [176.5320, 462.8410, 184.6140],
          [175.8580, 465.0300, 181.5790],
          [178.1930, 465.7110, 174.9000],
          [180.9150, 464.1190, 172.7670],
          [181.2510, 460.5420, 174.1130],
          [178.4430, 457.9660, 174.2140],
          [178.5420, 456.4790, 177.7180],
          [175.4730, 455.1590, 179.5500],
          [175.4670, 454.0690, 183.1870],
          [174.1210, 450.5430, 183.6830],
          [173.9850, 450.6340, 187.5030],
          [174.3640, 452.9240, 190.5240],
          [177.5140, 454.2200, 192.1820],
          [179.3060, 451.6570, 194.3540],
          [181.4890, 453.2110, 197.0570],
          [184.7390, 451.3110, 197.6330],
          [186.4870, 454.0230, 199.6690],
          [185.5270, 457.3810, 201.1790],
          [186.7910, 459.0050, 197.9610],
          [186.6010, 456.0430, 195.5320],
          [183.3680, 455.3620, 193.6270],
          [182.9110, 452.9430, 190.7280],
          [180.0190, 452.5230, 188.3060],
          [179.0910, 450.0250, 185.5960],
          [178.9850, 451.7840, 182.2280],
          [179.0510, 451.0080, 178.5100],
          [181.0530, 453.1550, 176.0780],
          [180.1320, 452.7170, 172.4060],
          [182.0570, 455.4530, 170.6120],
          [181.8140, 459.1820, 170.0770],
          [178.5990, 461.0580, 169.3130],
          [177.7620, 463.8650, 166.9020]]),
  'residue_ids': tensor([  4.,   4.,   4.,  ..., 182., 182., 182.]),
  'chain_ids': tensor([0., 0., 0.,  ..., 1., 1., 1.])},
 'id': '8phr__X4_UNDEFINED--8phr__W4_UNDEFINED',
 'sample_id': '8phr__X4_UNDEFINED-R--8phr__W4_UNDEFINED-L',
 'target_id': '8phr__X4_UNDEFINED-R--8phr__W4_UNDEFINED-L'}
# Since we used the default crop_equal_monomer_shapes=True, the feature and target complex coordinates should have identical shapes
assert (
    data_item["feature_complex"]["atom_coordinates"].shape
    == data_item["target_complex"]["atom_coordinates"].shape
)

data_item["feature_complex"]["atom_coordinates"].shape
torch.Size([1316, 3])
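Because the feature and target coordinates share a shape, you can compare them atom by atom, for example as a coordinate-level loss. The sketch below uses a hypothetical helper, `coordinate_rmsd`, that is not part of pinder. It performs no superposition, so it measures deviation in the current frame only and is not a substitute for pinder's own evaluation metrics.

```python
import torch


def coordinate_rmsd(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Root-mean-square deviation between two (N, 3) coordinate tensors.

    Hypothetical helper: no superposition is performed, and the shapes must
    match (which crop_equal_monomer_shapes=True is meant to guarantee).
    """
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {tuple(pred.shape)} vs {tuple(target.shape)}")
    squared_dist = ((pred - target) ** 2).sum(dim=-1)  # per-atom squared distance
    return squared_dist.mean().sqrt()


# Toy check: shifting every atom by 1 Å along x gives an RMSD of exactly 1.
a = torch.zeros(4, 3)
b = a.clone()
b[:, 0] += 1.0
print(coordinate_rmsd(a, b))  # tensor(1.)
```

On the data item above this would be called as `coordinate_rmsd(data_item["feature_complex"]["atom_coordinates"], data_item["target_complex"]["atom_coordinates"])`.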
help(PinderDataset)
Help on class PinderDataset in module pinder.core.loader.dataset:

class PinderDataset(torch.utils.data.dataset.Dataset)
 |  PinderDataset(split: 'str | None' = None, index: 'pd.DataFrame | None' = None, metadata: 'pd.DataFrame | None' = None, monomer_priority: 'str' = 'holo', base_filters: 'list[PinderFilterBase]' = [], sub_filters: 'list[PinderFilterSubBase]' = [], structure_filters: 'list[StructureFilter]' = [], structure_transforms_target: 'list[StructureTransform]' = [], structure_transforms_feature: 'list[StructureTransform]' = [], transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = <function structure2tensor_transform at 0x7f69602b55a0>, target_transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = <function structure2tensor_transform at 0x7f69602b55a0>, ids: 'list[str] | None' = None, fallback_to_holo: 'bool' = True, use_canonical_apo: 'bool' = True, crop_equal_monomer_shapes: 'bool' = True, index_query: 'str | None' = None, metadata_query: 'str | None' = None, pre_specified_monomers: 'dict[str, str] | pd.DataFrame | None' = None, **kwargs: 'Any') -> 'None'
 |  
 |  Method resolution order:
 |      PinderDataset
 |      torch.utils.data.dataset.Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __getitem__(self, idx: 'int') -> 'dict[str, dict[str, torch.Tensor] | torch.Tensor]'
 |  
 |  __init__(self, split: 'str | None' = None, index: 'pd.DataFrame | None' = None, metadata: 'pd.DataFrame | None' = None, monomer_priority: 'str' = 'holo', base_filters: 'list[PinderFilterBase]' = [], sub_filters: 'list[PinderFilterSubBase]' = [], structure_filters: 'list[StructureFilter]' = [], structure_transforms_target: 'list[StructureTransform]' = [], structure_transforms_feature: 'list[StructureTransform]' = [], transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = <function structure2tensor_transform at 0x7f69602b55a0>, target_transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = <function structure2tensor_transform at 0x7f69602b55a0>, ids: 'list[str] | None' = None, fallback_to_holo: 'bool' = True, use_canonical_apo: 'bool' = True, crop_equal_monomer_shapes: 'bool' = True, index_query: 'str | None' = None, metadata_query: 'str | None' = None, pre_specified_monomers: 'dict[str, str] | pd.DataFrame | None' = None, **kwargs: 'Any') -> 'None'
 |      Initialize self.  See help(type(self)) for accurate signature.
 |  
 |  __len__(self) -> 'int'
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |  
 |  __annotations__ = {}
 |  
 |  __parameters__ = ()
 |  
 |  ----------------------------------------------------------------------
 |  Methods inherited from torch.utils.data.dataset.Dataset:
 |  
 |  __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]'
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors inherited from torch.utils.data.dataset.Dataset:
 |  
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes inherited from torch.utils.data.dataset.Dataset:
 |  
 |  __orig_bases__ = (typing.Generic[+T_co],)
 |  
 |  ----------------------------------------------------------------------
 |  Class methods inherited from typing.Generic:
 |  
 |  __class_getitem__(params) from builtins.type
 |  
 |  __init_subclass__(*args, **kwargs) from builtins.type
 |      This method is called when a class is subclassed.
 |      
 |      The default implementation does nothing. It may be
 |      overridden to extend subclasses.

Torch DataLoader for PinderDataset

The PinderDataset can be served by a torch.utils.data.DataLoader.

The convenience function pinder.core.loader.dataset.get_torch_loader takes a PinderDataset and returns a DataLoader that wraps it.

We can use the default collate_fn (pinder.core.loader.dataset.collate_batch) to merge multiple systems (Dataset items) into mini-batches of tensors:

from pinder.core.loader.dataset import collate_batch, get_torch_loader
from torch.utils.data import DataLoader

batch_size = 2
train_dataloader = get_torch_loader(
    train_dataset, 
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_batch,
    num_workers=0, 
)
assert isinstance(train_dataloader, DataLoader)
assert hasattr(train_dataloader, "dataset")

# Get a batch from the dataloader
batch = next(iter(train_dataloader))

# expected batch dict keys
assert set(batch.keys()) == {
    "target_complex",
    "feature_complex",
    "id",
    "sample_id",
    "target_id",
}
assert isinstance(batch["target_complex"], dict)
assert isinstance(batch["target_complex"]["atom_coordinates"], torch.Tensor)
feature_coords = batch["feature_complex"]["atom_coordinates"]
# Ensure batch size propagates to tensor dims
assert feature_coords.shape[0] == batch_size
# Ensure coordinates have dim 3
assert feature_coords.shape[2] == 3
2024-11-15 12:13:41,136 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
2024-11-15 12:13:41,860 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
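The assertions above imply that batched coordinates have shape (batch, atoms, 3), which requires some form of padding when systems have different atom counts. The sketch below is a hypothetical collate function, `pad_coordinates_collate`, written with plain torch to show one way this can be done. It is not pinder's collate_batch, and the `mask` key is our own addition for marking which rows are real atoms.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_coordinates_collate(items: list[dict]) -> dict:
    """Hypothetical collate_fn: pad variable-length (N_i, 3) coordinate tensors.

    Returns a (batch, max_N, 3) tensor plus a boolean mask marking real atoms.
    """
    coords = [item["atom_coordinates"] for item in items]
    lengths = torch.tensor([c.shape[0] for c in coords])
    padded = pad_sequence(coords, batch_first=True, padding_value=0.0)
    # mask[b, i] is True where atom i of system b is real (not padding)
    mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]
    return {"atom_coordinates": padded, "mask": mask, "id": [item["id"] for item in items]}


# Two toy "systems" with 5 and 3 atoms; a plain list works as a map-style dataset.
toy_items = [
    {"id": "sys_a", "atom_coordinates": torch.randn(5, 3)},
    {"id": "sys_b", "atom_coordinates": torch.randn(3, 3)},
]
loader = DataLoader(toy_items, batch_size=2, shuffle=False, collate_fn=pad_coordinates_collate)
batch = next(iter(loader))
print(batch["atom_coordinates"].shape)  # torch.Size([2, 5, 3])
print(batch["mask"].sum(dim=1))  # tensor([5, 3])
```

Carrying an explicit mask lets downstream losses and attention layers ignore padded atoms instead of treating the zero-filled rows as real coordinates.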

Torch Geometric Dataset

# Make sure to install torch_cluster
# !pip install torch_cluster
from pinder.core.loader.dataset import PPIDataset
from pinder.core.loader.geodata import NodeRepresentation

nodes = {NodeRepresentation("atom"), NodeRepresentation("residue")}

train_dataset = PPIDataset(
    node_types=nodes,
    split="train",
    monomer1="holo_receptor",
    monomer2="holo_ligand",
    limit_by=5,
    force_reload=True,
    parallel=False,
)
assert len(train_dataset) == 5

train_dataset
Processing...
2024-11-15 12:13:49,610 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
2024-11-15 12:13:50,822 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
2024-11-15 12:13:51,243 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
2024-11-15 12:13:52,054 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
2024-11-15 12:13:52,436 | pinder.core.loader.dataset:550 | INFO : Finished processing, only 5 systems
Done!
PPIDataset(5)
from torch_geometric.data import HeteroData
from pinder.core import get_index

# Check that limit_by=5 capped the number of loaded systems at 5
pindex = get_index()
raw_ids = set(train_dataset.raw_file_names)
assert len(raw_ids.intersection(set(pindex.id))) == 5
processed_ids = {f.stem for f in train_dataset.processed_file_names}
# Check that all 5 IDs were processed and saved to disk as .pt files
assert len(processed_ids.intersection(set(pindex.id))) == 5

# Let's get an item from the dataset by index 
data_item = train_dataset[0]
assert isinstance(data_item, HeteroData)
data_item
PairedPDB(
  ligand_residue={
    residueid=[121, 1],
    pos=[121, 3],
    edge_index=[2, 1210],
    chain=[1],
  },
  receptor_residue={
    residueid=[144, 1],
    pos=[144, 3],
    edge_index=[2, 1440],
    chain=[1],
  },
  ligand_atom={
    x=[958, 1],
    pos=[958, 3],
    edge_index=[2, 9580],
  },
  receptor_atom={
    x=[1119, 1],
    pos=[1119, 3],
    edge_index=[2, 11190],
  },
  pdb={
    id=[1],
    num_nodes=1,
  }
)
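In the HeteroData printed above, every edge_index holds exactly 10 times as many edges as its node type has nodes (121 → 1210, 144 → 1440, 958 → 9580, 1119 → 11190). That pattern is consistent with a k-nearest-neighbour graph with k = 10, although we have not verified how pinder.core.loader.geodata actually builds its edges (torch_cluster, which the cell above asks you to install, provides a `knn_graph` function for this). The sketch below reproduces that edge_index shape with plain torch; `knn_edge_index` is a hypothetical helper, not a pinder API.

```python
import torch


def knn_edge_index(pos: torch.Tensor, k: int) -> torch.Tensor:
    """Build a k-nearest-neighbour edge_index of shape [2, N * k] from (N, 3) positions.

    Hypothetical helper. Row 0 holds source (neighbour) indices and row 1 target
    indices, following PyG's "source_to_target" convention. Self-loops are excluded.
    """
    dist = torch.cdist(pos, pos)  # (N, N) pairwise Euclidean distances
    dist.fill_diagonal_(float("inf"))  # never pick a node as its own neighbour
    nbr = dist.topk(k, dim=1, largest=False).indices  # (N, k) nearest neighbours per node
    target = torch.arange(pos.shape[0]).repeat_interleave(k)  # 0,0,...,1,1,...
    source = nbr.reshape(-1)  # neighbours of node 0, then node 1, ...
    return torch.stack([source, target], dim=0)


pos = torch.randn(121, 3)  # same node count as ligand_residue above
edge_index = knn_edge_index(pos, k=10)
print(edge_index.shape)  # torch.Size([2, 1210])
```

A dense `cdist` is O(N²) in memory, which is fine for a few thousand atoms; for larger graphs a dedicated routine such as torch_cluster's is the better choice.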
data_item.num_node_features
{'ligand_residue': 0,
 'receptor_residue': 0,
 'ligand_atom': 1,
 'receptor_atom': 1,
 'pdb': 0}
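The atom node types carry a single feature column, and the ligand_atom `x` values printed below appear to be integer category codes stored as floats (the repeating 0, 1, 2, 3 pattern looks like backbone atom names). If your model expects one-hot inputs, a conversion along these lines could work. `one_hot_atom_types` is a hypothetical helper, and `NUM_ATOM_TYPES` is an assumption: 35 is the largest code visible in the printed portion of this one item, so check the vocabulary pinder actually uses before relying on it.

```python
import torch
import torch.nn.functional as F

# Assumption: size of the atom-type vocabulary. 36 is only the minimum that covers
# the codes visible for this item; confirm against pinder's own encoding.
NUM_ATOM_TYPES = 36


def one_hot_atom_types(x: torch.Tensor, num_classes: int = NUM_ATOM_TYPES) -> torch.Tensor:
    """Turn an (N, 1) float tensor of integer codes into an (N, num_classes) float one-hot tensor."""
    codes = x.squeeze(-1).long()  # (N,) integer class indices, as F.one_hot requires
    return F.one_hot(codes, num_classes=num_classes).float()


x = torch.tensor([[0.0], [1.0], [2.0], [3.0], [8.0]])
print(one_hot_atom_types(x).shape)  # torch.Size([5, 36])
```

Applied to the dataset item, this would be `one_hot_atom_types(data_item["ligand_atom"].x)`.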
data_item["ligand_atom"]
{'x': tensor([[ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [21.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [10.],
        [23.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [17.],
        [ 5.],
        [ 9.],
        [20.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [17.],
        [ 5.],
        [ 9.],
        [20.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [21.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [32.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [22.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [32.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [17.],
        [ 5.],
        [ 9.],
        [20.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [24.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [35.],
        [23.],
        [ 7.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [11.],
        [30.],
        [ 5.],
        [14.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [24.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [24.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [27.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [24.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [11.],
        [30.],
        [ 5.],
        [14.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [27.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [35.],
        [23.],
        [ 7.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [10.],
        [23.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 6.],
        [31.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [21.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [17.],
        [ 5.],
        [ 9.],
        [20.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [35.],
        [23.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [24.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [22.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [33.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [11.],
        [30.],
        [ 5.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [32.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [21.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [17.],
        [ 5.],
        [ 9.],
        [20.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 6.],
        [31.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [26.],
        [25.],
        [11.],
        [27.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [35.],
        [23.],
        [ 7.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [17.],
        [ 5.],
        [ 9.],
        [20.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [17.],
        [ 5.],
        [ 9.],
        [20.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [35.],
        [23.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [34.],
        [30.],
        [ 4.],
        [13.],
        [16.],
        [18.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 6.],
        [31.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [27.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [32.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [35.],
        [23.],
        [ 7.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [35.],
        [23.],
        [ 7.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [35.],
        [23.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [22.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [17.],
        [ 5.],
        [ 9.],
        [20.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [32.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [11.],
        [30.],
        [ 5.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [27.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [32.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [32.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [27.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [33.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [32.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [35.],
        [23.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [35.],
        [23.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [26.],
        [25.],
        [11.],
        [27.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [11.],
        [30.],
        [ 5.],
        [14.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [21.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [32.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [24.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [17.],
        [ 5.],
        [ 9.],
        [20.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [10.],
        [23.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [21.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [27.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [24.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [24.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [21.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [31.],
        [28.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [35.],
        [23.],
        [ 7.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [21.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [29.],
        [12.],
        [24.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [10.],
        [23.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [19.],
        [32.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 8.],
        [15.],
        [ 7.],
        [25.],
        [11.],
        [30.],
        [ 5.],
        [ 0.],
        [ 1.],
        [ 2.],
        [ 3.]]), 'pos': tensor([[207.6910, 164.0670, 172.7860],
        [208.6350, 163.8290, 173.9130],
        [209.3790, 165.0800, 174.3180],
        ...,
        [240.2730, 173.0310, 178.0680],
        [240.1680, 174.1280, 177.0260],
        [240.3170, 173.8780, 175.8300]]), 'edge_index': tensor([[  1,   2,   4,  ..., 946, 894, 893],
        [  0,   0,   0,  ..., 957, 957, 957]])}
# We can also get an item by its system ID 
data_item = train_dataset.get_filename("8phr__X4_UNDEFINED--8phr__W4_UNDEFINED")
data_item
PairedPDB(
  ligand_residue={
    residueid=[158, 1],
    pos=[158, 3],
    edge_index=[2, 1580],
    chain=[1],
  },
  receptor_residue={
    residueid=[171, 1],
    pos=[171, 3],
    edge_index=[2, 1710],
    chain=[1],
  },
  ligand_atom={
    x=[1198, 1],
    pos=[1198, 3],
    edge_index=[2, 11980],
  },
  receptor_atom={
    x=[1358, 1],
    pos=[1358, 3],
    edge_index=[2, 13580],
  },
  pdb={
    id=[1],
    num_nodes=1,
  }
)
train_dataset.print_summary()
PPIDataset (#graphs=5):
+------------+----------+----------+
|            |   #nodes |   #edges |
|------------+----------+----------|
| mean       |   2612.4 |        0 |
| std        |   1631.6 |        0 |
| min        |    999   |        0 |
| quantile25 |   1600   |        0 |
| median     |   2343   |        0 |
| quantile75 |   2886   |        0 |
| max        |   5234   |        0 |
+------------+----------+----------+
Number of nodes per node type:
+------------+------------------+--------------------+---------------+-----------------+-------+
|            |   ligand_residue |   receptor_residue |   ligand_atom |   receptor_atom |   pdb |
|------------+------------------+--------------------+---------------+-----------------+-------|
| mean       |            140.2 |              152.8 |        1103.4 |          1215   |     1 |
| std        |             90.2 |               87.5 |         738.9 |           717.1 |     0 |
| min        |             58   |               62   |         426   |           452   |     1 |
| quantile25 |             78   |               97   |         622   |           802   |     1 |
| median     |            121   |              144   |         958   |          1119   |     1 |
| quantile75 |            158   |              171   |        1198   |          1358   |     1 |
| max        |            286   |              290   |        2313   |          2344   |     1 |
+------------+------------------+--------------------+---------------+-----------------+-------+

PairedPDB torch-geometric HeteroData object

The PPIDataset represents each data item as a PairedPDB object, in which the receptor and ligand are encoded separately in a heterogeneous graph (torch_geometric.data.HeteroData) that holds multiple node and/or edge types in disjoint storage objects.

The PPIDataset leverages the PinderLoader to apply optional filters and/or transforms, and it caches processed systems by saving them to disk as .pt files. PairedPDB implements a conversion method, .from_pinder_system, that takes a PinderSystem and converts it into a PairedPDB object.

For more details on the torch-geometric Dataset APIs, please refer to the tutorials.
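To make the caching idea concrete, here is a minimal sketch of the "process once, load from disk afterwards" pattern. It is not PPIDataset's implementation: load_or_process and fake_process are hypothetical helpers, and pickle stands in for the torch .pt serialization so that the example only needs the standard library.

```python
import os
import pickle
import tempfile
from typing import Any, Callable


def load_or_process(item_id: str, cache_dir: str, process: Callable[[str], Any]) -> Any:
    """Return the processed item for item_id, computing it only on a cache miss.

    Hypothetical helper: pickle stands in for torch.save/torch.load.
    """
    path = os.path.join(cache_dir, f"{item_id}.pkl")
    if os.path.exists(path):
        # Cache hit: reuse the previously processed item
        with open(path, "rb") as fh:
            return pickle.load(fh)
    # Cache miss: process the raw item once and persist the result
    result = process(item_id)
    with open(path, "wb") as fh:
        pickle.dump(result, fh)
    return result


calls = []


def fake_process(item_id: str) -> dict:
    # Hypothetical stand-in for converting a PinderSystem into a graph object
    calls.append(item_id)
    return {"id": item_id}


with tempfile.TemporaryDirectory() as cache_dir:
    first = load_or_process("3s9d__B1_P48551--3s9d__A1_P01563", cache_dir, fake_process)
    second = load_or_process("3s9d__B1_P48551--3s9d__A1_P01563", cache_dir, fake_process)

print(first == second, len(calls))  # True 1
```

The second call is served from disk, so the expensive conversion runs only once per system. In practice you would also need a way to invalidate stale cache entries when the processing logic changes.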

from pinder.core.loader.geodata import PairedPDB
from torch_geometric.data import HeteroData

pinder_id = "3s9d__B1_P48551--3s9d__A1_P01563"
system = PinderSystem(pinder_id)

holo_data = PairedPDB.from_pinder_system(
    system=system,
    monomer1="holo_receptor", monomer2="holo_ligand",
    node_types=nodes,
)
assert isinstance(holo_data, HeteroData)
expected_node_types = [
    'ligand_residue', 'receptor_residue', 'ligand_atom', 'receptor_atom'
]
assert holo_data.num_nodes == 2780
assert holo_data.num_edges == 0
assert isinstance(holo_data.num_node_features, dict)
expected_num_feats = {
    'ligand_residue': 0,
    'receptor_residue': 0,
    'ligand_atom': 1,
    'receptor_atom': 1
}
for k, v in expected_num_feats.items():
    assert holo_data.num_node_features[k] == v

assert holo_data.node_types == expected_node_types


holo_data
2024-11-15 12:13:53,464 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=15
PairedPDB(
  ligand_residue={
    residueid=[127, 1],
    pos=[127, 3],
    edge_index=[2, 1270],
    chain=[1],
  },
  receptor_residue={
    residueid=[180, 1],
    pos=[180, 3],
    edge_index=[2, 1800],
    chain=[1],
  },
  ligand_atom={
    x=[1032, 1],
    pos=[1032, 3],
    edge_index=[2, 10320],
  },
  receptor_atom={
    x=[1441, 1],
    pos=[1441, 3],
    edge_index=[2, 14410],
  }
)
# You can target specific monomers (apo/holo/pred) for the receptor and ligand
apo_data = PairedPDB.from_pinder_system(
    system=system,
    monomer1="apo_receptor", monomer2="apo_ligand",
    node_types=nodes,
)
assert isinstance(apo_data, HeteroData)

assert apo_data.num_nodes == 3437
assert apo_data.num_edges == 0
assert isinstance(apo_data.num_node_features, dict)
expected_num_feats = {
    'ligand_residue': 0,
    'receptor_residue': 0,
    'ligand_atom': 1,
    'receptor_atom': 1,
}
for k, v in expected_num_feats.items():
    assert apo_data.num_node_features[k] == v

assert apo_data.node_types == expected_node_types

apo_data
PairedPDB(
  ligand_residue={
    residueid=[165, 1],
    pos=[165, 3],
    edge_index=[2, 1650],
    chain=[1],
  },
  receptor_residue={
    residueid=[212, 1],
    pos=[212, 3],
    edge_index=[2, 2120],
    chain=[1],
  },
  ligand_atom={
    x=[1350, 1],
    pos=[1350, 3],
    edge_index=[2, 13500],
  },
  receptor_atom={
    x=[1710, 1],
    pos=[1710, 3],
    edge_index=[2, 17100],
  }
)
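The num_nodes assertions above follow directly from the printed objects: for a HeteroData object, num_nodes is the total over all node types. The short check below recomputes both totals from the per-type counts shown in the holo and apo outputs. The num_edges == 0 assertions appear to follow from how the edges are stored: each edge_index sits inside a node-type store rather than under a (source, relation, target) edge type, and num_edges only counts edge types.

```python
# Per-node-type counts copied from the printed PairedPDB objects above
holo_counts = {
    "ligand_residue": 127,
    "receptor_residue": 180,
    "ligand_atom": 1032,
    "receptor_atom": 1441,
}
apo_counts = {
    "ligand_residue": 165,
    "receptor_residue": 212,
    "ligand_atom": 1350,
    "receptor_atom": 1710,
}

# num_nodes of a heterogeneous graph is the sum over its node types
print(sum(holo_counts.values()))  # 2780
print(sum(apo_counts.values()))   # 3437
```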

Torch-geometric DataLoader

The PPIDataset can be served by a torch_geometric.loader.DataLoader.

The convenience function pinder.core.loader.dataset.get_geo_loader takes a PPIDataset and returns a DataLoader for it.

from pinder.core.loader.dataset import get_geo_loader
from torch_geometric.loader import DataLoader


loader = get_geo_loader(train_dataset)

assert isinstance(loader, DataLoader)
assert hasattr(loader, "dataset")
ds = loader.dataset
assert len(ds) == 5


loader
<torch_geometric.loader.dataloader.DataLoader at 0x7f697dbfdb40>
ds
PPIDataset(5)
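Unlike the plain PyTorch DataLoader used later in this tutorial, the torch-geometric DataLoader can batch graphs of different sizes without padding: it concatenates the graphs into one large disconnected graph, shifts each graph's edge indices by the number of nodes that precede it, and records a batch vector mapping every node back to its graph. The pure-Python function below is a simplified illustration of that idea for homogeneous graphs; batch_graphs is a hypothetical helper, not part of torch_geometric.

```python
def batch_graphs(graphs):
    """Merge (node_positions, (src, dst)) graphs into one disconnected graph."""
    pos, src, dst, batch = [], [], [], []
    offset = 0
    for graph_id, (nodes, (g_src, g_dst)) in enumerate(graphs):
        pos.extend(nodes)
        # Shift edge endpoints so they point at this graph's rows in the merged node list
        src.extend(s + offset for s in g_src)
        dst.extend(d + offset for d in g_dst)
        # Remember which graph each node came from
        batch.extend([graph_id] * len(nodes))
        offset += len(nodes)
    return pos, (src, dst), batch


two_nodes = ([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], ([0, 1], [1, 0]))
three_nodes = ([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], ([0, 1, 2], [1, 2, 0]))
pos, edge_index, batch = batch_graphs([two_nodes, three_nodes])
print(edge_index)  # ([0, 1, 2, 3, 4], [1, 0, 3, 4, 2])
print(batch)       # [0, 0, 1, 1, 1]
```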

Implementing your own PyTorch Dataset & DataLoader for pinder

The previous sections covered two example implementations built on the torch and torch-geometric APIs. Below, we illustrate how you can write your own PyTorch data pipeline to feed your model.

We will focus on writing a minimalistic example that simply fetches PDB files and returns the coordinates as tensors. We will also include an example of a sampler function for the dataloader to implement diversity sampling in your workflow.

Defining the Dataset

Below we will write a barebones torch.utils.data.Dataset object that implements at a minimum:

  • __init__ method

  • __len__ method returning the number of items in the dataset

  • __getitem__ method that returns an item in the dataset
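For map-style datasets, torch.utils.data.DataLoader relies only on this __len__/__getitem__ protocol; subclassing torch.utils.data.Dataset is the conventional way to signal it. The toy class below is a hypothetical example that keeps coordinates in memory instead of loading pinder structures, so the protocol can be seen in isolation.

```python
class ToyCoordsDataset:
    """Hypothetical map-style dataset: integer index -> list of (x, y, z) rows."""

    def __init__(self, systems: dict[str, list[list[float]]]) -> None:
        # Sort the ids so each integer index maps to a stable system id
        self.ids = sorted(systems)
        self.systems = systems

    def __len__(self) -> int:
        # Used by samplers to know which indices are valid
        return len(self.ids)

    def __getitem__(self, idx: int) -> list[list[float]]:
        # Look up the system id for this index and return its coordinates
        return self.systems[self.ids[idx]]


ds = ToyCoordsDataset({"sys_b": [[1.0, 0.0, 0.0]], "sys_a": [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]})
print(len(ds), ds[0])  # 2 [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
```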

We will also add an option to include apo and predicted monomer types via a monomer_priority argument. The argument defaults to “random”, meaning that a monomer type is sampled at random from the monomer types available for a given system. We could instead target a specific monomer type by setting this argument to one of “holo”, “apo”, or “predicted”. Not every system in the training set has apo or predicted structures available, so we also add a fallback_to_holo argument that controls whether holo monomers are used when the selected monomer type is not available.
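The interplay between monomer_priority and fallback_to_holo can be summarized with a small, hypothetical helper. This is a sketch of the semantics described above, not pinder's select_monomer (which the dataset below calls, and which also takes options such as use_canonical_apo). The availability set passed in is an assumption standing in for whatever the system's metadata reports.

```python
import random
from typing import Optional

MONOMER_TYPES = ("holo", "apo", "predicted")


def choose_monomer_type(
    available: set[str],
    priority: str = "random",
    fallback_to_holo: bool = True,
    rng: Optional[random.Random] = None,
) -> str:
    """Pick a monomer type for one system (hypothetical helper)."""
    rng = rng if rng is not None else random.Random()
    if priority == "random":
        # Sample only among the monomer types this system actually has
        candidates = [m for m in MONOMER_TYPES if m in available]
        return rng.choice(candidates)
    if priority in available:
        return priority
    if fallback_to_holo and "holo" in available:
        # Requested type is missing, so fall back to the holo monomers
        return "holo"
    raise ValueError(f"Monomer type {priority!r} is unavailable and fallback_to_holo is disabled")


print(choose_monomer_type({"holo"}, priority="apo"))          # holo
print(choose_monomer_type({"holo", "apo"}, priority="apo"))   # apo
```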

We also define three interfaces for applying filters to our dataset:

  1. metadata_filter: a query string to apply to the pinder metadata pandas DataFrame

  2. system_filters: list[PinderFilterBase]: a list of filters that inherit from the base class PinderFilterBase, which serves as the abstraction layer for defining PinderSystem-based filters

  3. structure_filters: list[StructureFilter]: a list of filters that inherit from the base class StructureFilter, which serves as the abstraction layer for defining Structure-based filters
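The two callable-based interfaces share one contract: a filter is called on an item and returns True when the item should be kept, and an item is kept only if every filter passes. The sketch below illustrates that contract with a hypothetical ToyStructure and a hypothetical MinDistinctAtomNames filter; in the real pipeline you would subclass StructureFilter or PinderFilterBase instead. The metadata_filter works differently: it is a pandas query string, such as "(buried_sasa >= 500)", evaluated against the metadata DataFrame.

```python
from dataclasses import dataclass


@dataclass
class ToyStructure:
    # Hypothetical stand-in for a pinder Structure: just the atom names
    atom_names: list[str]


class MinDistinctAtomNames:
    """Hypothetical structure filter: keep structures with enough distinct atom names."""

    def __init__(self, min_count: int) -> None:
        self.min_count = min_count

    def __call__(self, structure: ToyStructure) -> bool:
        return len(set(structure.atom_names)) >= self.min_count


def passes_all(structure: ToyStructure, structure_filters: list) -> bool:
    # all() stops at the first failing filter, like a loop with an early break
    return all(f(structure) for f in structure_filters)


ca_only = ToyStructure(["CA", "CA", "CA"])
backbone = ToyStructure(["N", "CA", "C", "O"])
print(passes_all(ca_only, [MinDistinctAtomNames(4)]), passes_all(backbone, [MinDistinctAtomNames(4)]))  # False True
```

Note that an empty filter list keeps every item, which matches the default of no filters in the dataset below.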

import random

import numpy as np
from numpy.typing import NDArray
from torch.utils.data import Dataset
from pinder.core import get_index, get_metadata, PinderSystem
from pinder.core.loader import filters
from pinder.core.loader.loader import _create_target_feature_complex, select_monomer
from pinder.core.loader.structure import Structure

index = get_index()
metadata = get_metadata()


class CustomPinderDataset(Dataset):
    def __init__(
        self, 
        split: str, 
        monomer_priority: str = "random", 
        fallback_to_holo: bool = True, 
        crop_equal_monomer_shapes: bool = True,
        use_canonical_apo: bool = True,
        transform=None, 
        target_transform=None,
        structure_filters: list[filters.StructureFilter] = [],
        system_filters: list[filters.PinderFilterBase] = [],
        metadata_filter: str | None = None,
        max_load_attempts: int = 10,
    ) -> None:
        # Which split/mode we are using for the dataset instance
        self.split = split
        self.monomer_priority = monomer_priority
        self.fallback_to_holo = fallback_to_holo
        self.crop_equal_monomer_shapes = crop_equal_monomer_shapes

        # Optional transform and target transform to apply (will be covered shortly)
        self.transform = transform
        self.target_transform = target_transform
        # Optional system-level filters to apply
        self.system_filters = system_filters
        # Optional structure filters to apply
        self.structure_filters = structure_filters

        # Maximum number of times to try sampling another index from the dataset until an exception is raised
        self.max_load_attempts = max_load_attempts
        # Whether we should use canonical apo structures (apo_R/L_pdb columns in pinder index) if apo monomers are selected
        self.use_canonical_apo = use_canonical_apo
        
        # Define the subset of the pinder index and metadata corresponding to the split of our dataset instance 
        self.index = index.query(f'split == "{split}"').reset_index(drop=True)
        self.metadata = metadata[metadata["id"].isin(set(self.index.id))].reset_index(drop=True)
        if metadata_filter:
            try:
                self.metadata = self.metadata.query(metadata_filter).reset_index(drop=True)
            except Exception as e:
                print(f"Failed to apply metadata_filter={metadata_filter}: {e}")
        
        self.index = self.index[self.index["id"].isin(set(self.metadata.id))].reset_index(drop=True)

    def __len__(self):
        return len(self.index)
        
    def __getitem__(self, idx: int) -> tuple[NDArray[np.double], NDArray[np.double]]:
        valid_idx = False
        attempts = 0
        while not valid_idx and attempts < self.max_load_attempts:
            attempts += 1
            row = self.index.iloc[idx]
            system = PinderSystem(row.id)
            
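            system = self.apply_system_filters(system)
            if not isinstance(system, PinderSystem):
                # The system failed a filter: try another index before raising IndexError
                idx = random.randrange(len(self))
                continue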

            selected_monomers = select_monomer(
                row,
                self.monomer_priority,
                self.fallback_to_holo,
                self.use_canonical_apo,
            )
            # With the system and selected_monomers objects, we can now create a pair of dimer complexes
            # Below we leverage the existing utility from the PinderLoader (_create_target_feature_complex)
            target_complex, feature_complex = _create_target_feature_complex(
                system, selected_monomers, self.crop_equal_monomer_shapes, self.fallback_to_holo
            )
            valid_idx = self.apply_structure_filters(target_complex)
            if not valid_idx:
                # Try another index before raising IndexError
                idx = random.choice(list(range(len(self))))

        if not valid_idx:
            raise IndexError(
                f"Unable to find a valid item in the dataset satisfying filters at {idx} after {attempts} attempts!"
            )
        if self.transform is not None:
            feature_complex = self.transform(feature_complex)
        if self.target_transform is not None:
            target_complex = self.target_transform(target_complex)
        return feature_complex, target_complex

    def apply_structure_filters(self, structure: Structure) -> bool:
        pass_filters = True
        for structure_filter in self.structure_filters:
            if not structure_filter(structure):
                pass_filters = False
                break
        return pass_filters

    def apply_system_filters(self, system: PinderSystem) -> PinderSystem | bool:
        for system_filter in self.system_filters:
            if isinstance(system_filter, filters.PinderFilterBase):
                # Each filter is callable and returns False when the system should be skipped
                if not system_filter(system):
                    return False
        return system

    def __repr__(self) -> str:
        return f"CustomPinderDataset(split={self.split}, monomers={self.monomer_priority}, systems={len(self)})"
# Note the selected monomers, indicated by the Structure.pinder_id attributes. Since cropping is enabled, the feature and target complex AtomArrays have identical shapes.
test_data = CustomPinderDataset(split="test")
test_data[0]
(Structure(
     filepath=/home/runner/.local/share/pinder/2024-02/test_set_pdbs/7rzb__A1_A0A229LVN5-R--7rzb__A2_A0A229LVN5-L.pdb,
     uniprot_map=<class 'pandas.core.frame.DataFrame'> with shape (294, 14),
     pinder_id='7rzb__A1_A0A229LVN5-R--7rzb__A2_A0A229LVN5-L',
     atom_array=<class 'biotite.structure.AtomArray'> with shape (2092,),
     pdb_engine='fastpdb',
 ),
 Structure(
     filepath=/home/runner/.local/share/pinder/2024-02/test_set_pdbs/7rzb__A1_A0A229LVN5-R--7rzb__A2_A0A229LVN5-L.pdb,
     uniprot_map=<class 'pandas.core.frame.DataFrame'> with shape (294, 14),
     pinder_id='7rzb__A1_A0A229LVN5-R--7rzb__A2_A0A229LVN5-L',
     atom_array=<class 'biotite.structure.AtomArray'> with shape (2092,),
     pdb_engine='fastpdb',
 ))

In the above example, we wrote a torch Dataset that currently returns a tuple of Structure objects: one representing the “decoy” (the sample used for features) and the other representing the ground-truth “label”.

Of course we wouldn’t use these Structure objects directly in our model. In your model, you will have to choose which features to compute and which data structure to use to represent them.

While the PinderDataset and PPIDataset datasets provide examples of feature encodings, below we will adjust the transform and target_transform objects to simply return NumPy NDArray objects by default. Note: you can’t pass Structure objects to a torch DataLoader directly; you must first convert them into array-like structures supported by the default torch collate_fn. You can implement your own collate functions to change this behavior.

def default_transform(structure: Structure) -> NDArray[np.double]:
    return structure.coords
test_data = CustomPinderDataset(split="test", transform=default_transform, target_transform=default_transform)
test_data[0]
(array([[-19.331, -10.568,  -4.591],
        [-18.082, -10.305,  -3.883],
        [-16.867, -10.692,  -4.727],
        ...,
        [ 13.363,  11.96 ,  -6.092],
        [ 15.081,   9.028,  -5.759],
        [ 12.459,  10.17 ,  -6.767]], dtype=float32),
 array([[-13.215324 , -11.076902 ,  15.214827 ],
        [-14.133494 , -10.301851 ,  14.38616  ],
        [-13.614427 ,  -8.882866 ,  14.150835 ],
        ...,
        [  6.8049655, -15.96424  ,  20.159506 ],
        [  9.54528  , -16.992254 ,  18.400843 ],
        [  7.021378 , -15.329907 ,  18.152586 ]], dtype=float32))

Implementing diversity sampling

Here, we provide an example of how one might use torch.utils.data.WeightedRandomSampler. However, users are free to implement diversity sampling however they see fit. In this example, each system is sampled with a probability inversely proportional to the population size of its pinder cluster.
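Weighting each system by the inverse of its cluster size gives every cluster the same total weight, so clusters (rather than individual systems) are sampled uniformly in expectation and a heavily populated cluster no longer dominates an epoch. WeightedRandomSampler does not require the weights to sum to one. The worked example below checks this with hypothetical cluster labels, using the same 1 / cluster_count rule as inverse_cluster_size_sampler.

```python
from collections import Counter

# Hypothetical cluster labels, one per dataset index
cluster_ids = ["c1", "c1", "c1", "c2", "c2", "c3"]

cluster_sizes = Counter(cluster_ids)
# Weight each system by 1 / (number of systems in its cluster)
weights = [1.0 / cluster_sizes[c] for c in cluster_ids]

# Sum the weights per cluster: every cluster ends up with a total of 1
cluster_totals: dict[str, float] = {}
for c, w in zip(cluster_ids, weights):
    cluster_totals[c] = cluster_totals.get(c, 0.0) + w

for c, total in sorted(cluster_totals.items()):
    print(c, round(total, 6))
```

A single system in cluster c3 is therefore drawn as often as all three systems of c1 combined.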

from torch.utils.data import WeightedRandomSampler

def inverse_cluster_size_sampler(dataset: CustomPinderDataset, replacement: bool = True):
    index = dataset.index
    # Number of systems in each pinder cluster
    cluster_counts = index["cluster_id"].value_counts().rename("cluster_count")
    # Attach each system's cluster size (an inner merge keeps the row order of the left frame)
    index = index.merge(cluster_counts, left_on="cluster_id", right_index=True)
    # Undersample large clusters: weight each system by 1 / (size of its cluster)
    cluster_weights = 1.0 / torch.tensor(index.cluster_count.values)
    return WeightedRandomSampler(
        weights=cluster_weights,
        # Draw as many samples per epoch as there are clusters
        num_samples=len(cluster_counts),
        replacement=replacement,
    )

sampler = inverse_cluster_size_sampler(
    test_data,
    replacement=True,
)
sampler
<torch.utils.data.sampler.WeightedRandomSampler at 0x7f697dc7e200>

Defining the dataloader

Now that we have implemented a dataset and sampling function, we can tie everything together to implement the DataLoader.
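Under the hood, the three pieces cooperate in a simple loop: the sampler yields dataset indices, the dataset's __getitem__ loads the corresponding items, and the collate_fn merges each group of batch_size items into one batch. The generator below is a simplified, single-process sketch of that loop; simple_loader is hypothetical, and the real DataLoader additionally handles worker processes, pinned memory, and more.

```python
from typing import Callable, Iterable, Iterator, Sequence


def simple_loader(
    dataset: Sequence,
    indices: Iterable[int],
    batch_size: int,
    collate_fn: Callable[[list], object] = list,
) -> Iterator[object]:
    """Hypothetical, simplified picture of one epoch of a DataLoader."""
    batch = []
    for idx in indices:  # the sampler decides which indices are visited
        batch.append(dataset[idx])  # the dataset's __getitem__ loads each item
        if len(batch) == batch_size:
            yield collate_fn(batch)  # collate_fn merges items into one batch
            batch = []
    if batch:
        # Emit the final partial batch (like DataLoader's default drop_last=False)
        yield collate_fn(batch)


items = ["a", "b", "c", "d", "e"]
print(list(simple_loader(items, indices=[4, 0, 2], batch_size=2)))  # [['e', 'a'], ['c']]
```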

from torch.utils.data import DataLoader

test_dataloader = DataLoader(
    test_data, 
    batch_size=1, 
    # Mutually exclusive with sampler
    shuffle=False, 
    sampler=sampler,
)
test_features, test_labels = next(iter(test_dataloader))
test_features, test_labels
2024-11-15 12:13:55,075 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
(tensor([[[ -0.7287,  16.4913,   9.8068],
          [ -1.6878,  16.1829,  10.8788],
          [ -2.9926,  16.9600,  10.6499],
          ...,
          [ -3.8695, -10.1312,  23.1116],
          [ -1.7981,  -9.4029,  23.5338],
          [ -2.8393,  -9.4165,  22.7493]]]),
 tensor([[[-0.9892, 19.1343,  8.9643],
          [-1.5965, 18.3829, 10.0574],
          [-3.1128, 18.5439, 10.0904],
          ...,
          [-7.1361, -8.3044, 11.7833],
          [-8.0210, -9.1269, 12.3305],
          [-7.5183, -7.4610, 10.8316]]]))

Putting it all together, we can now get a train/val/test dataloader as follows:

from typing import Any, Callable


def get_loader(
    dataset: CustomPinderDataset,
    sampler: torch.utils.data.Sampler | None = None,
    batch_size: int = 2,
    # shuffle is mutually exclusive with sampler
    shuffle: bool = False,
    num_workers: int = 0,
    collate_fn: Callable[[list[tuple[NDArray[np.double], NDArray[np.double]]]], tuple[torch.Tensor, torch.Tensor]] | None = None,
    **kwargs: Any,
) -> "DataLoader[CustomPinderDataset]":
    if sampler is None and not shuffle:
        # Default to diversity sampling when no sampler instance is provided
        sampler = inverse_cluster_size_sampler(dataset)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        **kwargs,
    )


train_data = CustomPinderDataset(
    split="train", 
    structure_filters=[filters.MinAtomTypesFilter()], 
    metadata_filter="(buried_sasa >= 500)",
    transform=default_transform, 
    target_transform=default_transform,
)
train_dataloader = get_loader(
    train_data, 
    sampler=inverse_cluster_size_sampler(train_data),
    batch_size=1,
)
train_dataloader
<torch.utils.data.dataloader.DataLoader at 0x7f69817f7e80>
train_features, train_labels = next(iter(train_dataloader))
train_features, train_labels
2024-11-15 12:14:00,725 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
(tensor([[[161.0880, 121.9010, 154.0680],
          [159.7830, 122.5170, 153.9220],
          [159.8500, 124.0310, 153.9140],
          ...,
          [142.9230, 142.9120,  73.1190],
          [142.6480, 142.3420,  74.4990],
          [142.6460, 140.8530,  74.4910]]]),
 tensor([[[161.0880, 121.9010, 154.0680],
          [159.7830, 122.5170, 153.9220],
          [159.8500, 124.0310, 153.9140],
          ...,
          [142.9230, 142.9120,  73.1190],
          [142.6480, 142.3420,  74.4990],
          [142.6460, 140.8530,  74.4910]]]))

Using a larger batch size (collate_fn)

What if we want to use a larger batch size?

By default, the DataLoader uses torch.utils.data._utils.collate.default_collate (also exposed publicly as torch.utils.data.default_collate) as its collate_fn, which stacks tensors via torch.stack.

Since pinder systems differ in their number of atoms, using a batch size greater than 1 will cause this function to raise a RuntimeError with a message like: RuntimeError: stack expects each tensor to be equal size, but got [688, 3] at entry 0 and [1391, 3] at entry 1

We can leverage existing pinder utilities to write a custom collate_fn that pads tensor dimensions with dummy values so that the tensors can be stacked.
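Conceptually, padding extends every coordinate array to the length of the longest one in the batch using a sentinel value. The pure-Python sketch below illustrates this on lists; pad_coordinates is a hypothetical helper, and the boolean mask it also returns is our own addition rather than something pad_and_stack produces. Deriving the mask from the original lengths, instead of searching for the pad value afterwards, avoids any ambiguity if a real coordinate ever equals the sentinel.

```python
def pad_coordinates(
    batch: list[list[list[float]]], pad_value: float = -100.0
) -> tuple[list[list[list[float]]], list[list[bool]]]:
    """Pad variable-length (N_i, 3) coordinate lists to a common length (hypothetical)."""
    max_len = max(len(coords) for coords in batch)
    padded, mask = [], []
    for coords in batch:
        n_pad = max_len - len(coords)
        # Copy the real rows, then append sentinel rows up to max_len
        padded.append([list(xyz) for xyz in coords] + [[pad_value] * 3 for _ in range(n_pad)])
        # True marks real atoms, False marks padding
        mask.append([True] * len(coords) + [False] * n_pad)
    return padded, mask


short = [[0.0, 0.0, 0.0]]
longer = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
padded, mask = pad_coordinates([short, longer])
print(padded[0])  # [[0.0, 0.0, 0.0], [-100.0, -100.0, -100.0], [-100.0, -100.0, -100.0]]
print(mask[0])    # [True, False, False]
```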

from pinder.core.loader.dataset import pad_and_stack

help(pad_and_stack)
Help on function pad_and_stack in module pinder.core.loader.dataset:

pad_and_stack(tensors: 'list[Tensor]', dim: 'int' = 0, dims_to_pad: 'list[int] | None' = None, value: 'int | float | None' = None) -> 'Tensor'
    Pads a list of tensors to the maximum length observed along each dimension and then stacks them along a new dimension (given by `dim`).
    
    Parameters:
        tensors (list[Tensor]): A list of tensors to pad and stack
        dim (int): The new dimension to stack along.
        dims_to_pad (list[int] | None): The dimensions to pad
        value (int | float | None, optional): The value to pad with, by default None
    
    Returns:
        Tensor: The padded and stacked tensor. Below are examples of input and output shapes
            Example 1: Sequence features (although redundant with torch.rnn.utils.pad_sequence)
                input: [(2,), (7,)], dim: 0
                output: (2, 7)
            Example 2: Pair features (e.g., pairwise coordinates)
                input: [(4, 4, 3), (7, 7, 3)], dim: 0
                output: (2, 7, 7, 3)
def collate_coordinates(batch, coords_pad_value: int = -100):
    # Each item in the batch is a (feature_coords, target_coords) tuple
    feature_coords = []
    target_coords = []
    for feat, target in batch:
        # Convert NumPy arrays into float32 tensors
        if isinstance(feat, np.ndarray):
            feat = torch.tensor(feat, dtype=torch.float32)
        if isinstance(target, np.ndarray):
            target = torch.tensor(target, dtype=torch.float32)
        feature_coords.append(feat)
        target_coords.append(target)

    # Pad the (N_i, 3) tensors to the largest N_i and stack them into a (B, N_max, 3) tensor
    feature_coords = pad_and_stack(feature_coords, dim=0, value=coords_pad_value)
    target_coords = pad_and_stack(target_coords, dim=0, value=coords_pad_value)
    return feature_coords, target_coords


train_dataloader = get_loader(
    train_data, 
    sampler=inverse_cluster_size_sampler(train_data),
    collate_fn=collate_coordinates,
    batch_size=2,
)
train_features, train_labels = next(iter(train_dataloader))
train_features, train_labels
2024-11-15 12:14:01,733 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
2024-11-15 12:14:13,286 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
(tensor([[[ 246.8259,  207.0484,  365.5238],
          [ 246.7659,  208.4906,  365.3517],
          [ 247.3034,  208.7366,  363.9514],
          ...,
          [ 197.3790,  222.5565,  382.4614],
          [ 196.3958,  221.6868,  381.6732],
          [ 196.9519,  221.3230,  380.3475]],
 
         [[ 367.6263,  324.4850,  301.8851],
          [ 367.0830,  325.2843,  302.9723],
          [ 367.2429,  326.7630,  302.6137],
          ...,
          [-100.0000, -100.0000, -100.0000],
          [-100.0000, -100.0000, -100.0000],
          [-100.0000, -100.0000, -100.0000]]]),
 tensor([[[ 251.3640,  191.2350,  360.7730],
          [ 251.2140,  190.5020,  359.5220],
          [ 251.7570,  191.3220,  358.3540],
          ...,
          [ 203.7410,  227.2750,  385.3440],
          [ 204.1620,  227.1150,  383.8930],
          [ 204.7830,  225.7870,  383.6390]],
 
         [[ 367.8460,  324.9710,  300.2600],
          [ 367.1610,  325.4500,  301.4540],
          [ 367.0780,  326.9710,  301.4500],
          ...,
          [-100.0000, -100.0000, -100.0000],
          [-100.0000, -100.0000, -100.0000],
          [-100.0000, -100.0000, -100.0000]]]))
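The -100 rows in the output above are padding, so any loss computed on these batches should ignore them. Below is a minimal sketch of a masked mean squared error in pure Python; masked_mse is a hypothetical helper. It assumes, as in this tutorial's setup with crop_equal_monomer_shapes=True, that each feature/target pair has the same number of rows, and it treats a row as padding when all three target coordinates equal the pad value. In a real training loop you would express the same computation with tensor operations, and whether a raw coordinate error is a sensible objective depends on your model.

```python
def masked_mse(pred, target, pad_value: float = -100.0) -> float:
    """Mean squared coordinate error over the non-padded rows of a padded batch."""
    total_sq, n_rows = 0.0, 0
    for pred_sys, target_sys in zip(pred, target):
        for p, t in zip(pred_sys, target_sys):
            if all(v == pad_value for v in t):
                continue  # skip padding rows added by the collate function
            total_sq += sum((a - b) ** 2 for a, b in zip(p, t))
            n_rows += 1
    if n_rows == 0:
        raise ValueError("Batch contains only padding")
    # Average over every real (x, y, z) component
    return total_sq / (3 * n_rows)


pad = [-100.0, -100.0, -100.0]
pred = [[[0.0, 0.0, 0.0], pad], [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]]
target = [[[0.0, 0.0, 1.0], pad], [[1.0, 1.0, 1.0], [2.0, 2.0, 4.0]]]
print(masked_mse(pred, target))  # 5/9, about 0.556
```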