Pinder loader#
The goal of this tutorial is to provide hands-on examples of how you can use the pinder dataset in your ML workflow. Specifically, we will illustrate how to use the utilities provided by pinder to write your own data pipeline.
While this tutorial does not go into detail about how to write your own model, it covers the groundwork needed to interface with the structures in pinder and with the associated splits and metadata. You will of course want to implement your own featurization pipelines, data representations, etc., but the hope is that this tutorial clarifies how to access the data and use it in an ML framework.
Before proceeding with this tutorial section, you may find it helpful to review the existing tutorials available in pinder.
Specifically, the tutorials covering:
Accessing and loading data for training#
In order to access the train and val splits for PINDER, please refer to the pinder documentation.
Once you have downloaded the pinder dataset, either via the pinder package or directly through gsutil, you will have all of the necessary files for training.
To get a list of those systems and their split labels, refer to the pinder index.
We will start by looking at the most basic way to load items from the training and validation set: via PinderSystem objects.
Recap: PinderSystem and Structure classes#
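import torch
from pinder.core import get_index, PinderSystem


def get_system(system_id: str) -> PinderSystem:
    # Load the structures associated with a single dimer system ID.
    return PinderSystem(system_id)


# The index lists the systems together with their split labels.
index = get_index()
train = index[index.split == "train"].copy()
system = get_system(train.id.iloc[0])
system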
2024-11-15 12:09:52,578 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
PinderSystem(
entry = IndexEntry(
(
'split',
'train',
),
(
'id',
'8phr__X4_UNDEFINED--8phr__W4_UNDEFINED',
),
(
'pdb_id',
'8phr',
),
(
'cluster_id',
'cluster_24559_24559',
),
(
'cluster_id_R',
'cluster_24559',
),
(
'cluster_id_L',
'cluster_24559',
),
(
'pinder_s',
False,
),
(
'pinder_xl',
False,
),
(
'pinder_af2',
False,
),
(
'uniprot_R',
'UNDEFINED',
),
(
'uniprot_L',
'UNDEFINED',
),
(
'holo_R_pdb',
'8phr__X4_UNDEFINED-R.pdb',
),
(
'holo_L_pdb',
'8phr__W4_UNDEFINED-L.pdb',
),
(
'predicted_R_pdb',
'',
),
(
'predicted_L_pdb',
'',
),
(
'apo_R_pdb',
'',
),
(
'apo_L_pdb',
'',
),
(
'apo_R_pdbs',
'',
),
(
'apo_L_pdbs',
'',
),
(
'holo_R',
True,
),
(
'holo_L',
True,
),
(
'predicted_R',
False,
),
(
'predicted_L',
False,
),
(
'apo_R',
False,
),
(
'apo_L',
False,
),
(
'apo_R_quality',
'',
),
(
'apo_L_quality',
'',
),
(
'chain1_neff',
10.78125,
),
(
'chain2_neff',
11.1171875,
),
(
'chain_R',
'X4',
),
(
'chain_L',
'W4',
),
(
'contains_antibody',
False,
),
(
'contains_antigen',
False,
),
(
'contains_enzyme',
False,
),
)
native=Structure(
filepath=/home/runner/.local/share/pinder/2024-02/pdbs/8phr__X4_UNDEFINED--8phr__W4_UNDEFINED.pdb,
uniprot_map=None,
pinder_id='8phr__X4_UNDEFINED--8phr__W4_UNDEFINED',
atom_array=<class 'biotite.structure.AtomArray'> with shape (2556,),
pdb_engine='fastpdb',
)
holo_receptor=Structure(
filepath=/home/runner/.local/share/pinder/2024-02/pdbs/8phr__X4_UNDEFINED-R.pdb,
uniprot_map=/home/runner/.local/share/pinder/2024-02/mappings/8phr__X4_UNDEFINED-R.parquet,
pinder_id='8phr__X4_UNDEFINED-R',
atom_array=<class 'biotite.structure.AtomArray'> with shape (1358,),
pdb_engine='fastpdb',
)
holo_ligand=Structure(
filepath=/home/runner/.local/share/pinder/2024-02/pdbs/8phr__W4_UNDEFINED-L.pdb,
uniprot_map=/home/runner/.local/share/pinder/2024-02/mappings/8phr__W4_UNDEFINED-L.parquet,
pinder_id='8phr__W4_UNDEFINED-L',
atom_array=<class 'biotite.structure.AtomArray'> with shape (1198,),
pdb_engine='fastpdb',
)
apo_receptor=None
apo_ligand=None
pred_receptor=None
pred_ligand=None
)
Notice the printed PinderSystem object has the following properties:
- native: the ground-truth dimer complex
- holo_receptor: the receptor chain (monomer) from the ground-truth complex
- holo_ligand: the ligand chain (monomer) from the ground-truth complex
- apo_receptor: the canonical apo chain (monomer) paired to the receptor chain
- apo_ligand: the canonical apo chain (monomer) paired to the ligand chain
- pred_receptor: the AlphaFold2 predicted monomer paired to the receptor chain
- pred_ligand: the AlphaFold2 predicted monomer paired to the ligand chain
These properties are pointers to Structure objects. The Structure object provides the most direct mode of access to structures and associated properties.
Note: not all systems have an apo and/or predicted structure for all chains of the ground-truth dimer complex!
As was the case in the example above, when the alternative monomers are not available, the property will have a value of None.
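Because any of the alternative monomers can be None, downstream code usually needs a fallback rule. The following is a minimal sketch of one such rule and is not part of pinder: `pick_monomer` is a hypothetical helper that walks a priority list of the attribute prefixes shown above (apo, pred, holo) and returns the first structure that exists. A `SimpleNamespace` stands in for a real PinderSystem so the snippet runs without downloading any data.

```python
from types import SimpleNamespace
from typing import Any, Optional, Sequence, Tuple

# Attribute prefixes as they appear on PinderSystem: apo_*, pred_*, holo_*.
DEFAULT_PRIORITY: Tuple[str, ...] = ("apo", "pred", "holo")


def pick_monomer(
    system: Any,
    side: str,
    priority: Sequence[str] = DEFAULT_PRIORITY,
) -> Tuple[Optional[str], Any]:
    """Return (kind, structure) for the first monomer that is not None.

    `side` is "receptor" or "ligand". Hypothetical helper, not a pinder API.
    """
    if side not in ("receptor", "ligand"):
        raise ValueError(f"side must be 'receptor' or 'ligand', got {side!r}")
    for kind in priority:
        structure = getattr(system, f"{kind}_{side}", None)
        if structure is not None:
            return kind, structure
    return None, None


# Stand-in for a system that, like the example above, only has holo monomers.
toy_system = SimpleNamespace(
    holo_receptor="holo-R",
    holo_ligand="holo-L",
    apo_receptor=None,
    apo_ligand=None,
    pred_receptor=None,
    pred_ligand=None,
)
kind, structure = pick_monomer(toy_system, "receptor")
print(kind, structure)  # holo holo-R
```

With a real system you would pass the PinderSystem itself, e.g. `pick_monomer(system, "ligand")`; the priority order here is an arbitrary choice and should follow whatever your model expects at inference time.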
You can determine which systems have which alternative monomer pairings a priori by looking at the boolean columns in the index: apo_R and apo_L for the apo receptor and ligand, and predicted_R and predicted_L for the predicted receptor and ligand, respectively.
For instance, we can load a different system that does have both an apo receptor and an apo ligand:
apo_system = get_system(train.query('apo_R and apo_L').id.iloc[0])
receptor = apo_system.apo_receptor
ligand = apo_system.apo_ligand
receptor, ligand
2024-11-15 12:09:53,554 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=15
(Structure(
filepath=/home/runner/.local/share/pinder/2024-02/pdbs/3wdb__A1_P9WPC9.pdb,
uniprot_map=/home/runner/.local/share/pinder/2024-02/mappings/3wdb__A1_P9WPC9.parquet,
pinder_id='3wdb__A1_P9WPC9',
atom_array=<class 'biotite.structure.AtomArray'> with shape (1144,),
pdb_engine='fastpdb',
),
Structure(
filepath=/home/runner/.local/share/pinder/2024-02/pdbs/6ucr__A1_P9WPC9.pdb,
uniprot_map=/home/runner/.local/share/pinder/2024-02/mappings/6ucr__A1_P9WPC9.parquet,
pinder_id='6ucr__A1_P9WPC9',
atom_array=<class 'biotite.structure.AtomArray'> with shape (1193,),
pdb_engine='fastpdb',
))
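Before loading any structures, it can be useful to tally how many systems offer each alternative monomer. The sketch below counts the boolean index columns named earlier (apo_R, apo_L, predicted_R, predicted_L) over plain dictionaries. The rows are made-up toy data, and `monomer_availability` is a hypothetical helper; with the real index you could pass rows exported from the DataFrame, for example via `train.to_dict("records")`.

```python
from collections import Counter
from typing import Dict, Iterable, Mapping

# Boolean availability columns described in the index.
FLAG_COLUMNS = ("apo_R", "apo_L", "predicted_R", "predicted_L")
BOTH_APO = "apo_R and apo_L"


def monomer_availability(rows: Iterable[Mapping[str, object]]) -> Dict[str, int]:
    """Count how many rows have each alternative monomer available."""
    counts: Counter = Counter()
    for row in rows:
        for column in FLAG_COLUMNS:
            if row.get(column):
                counts[column] += 1
        # Systems where an apo structure exists for both chains.
        if row.get("apo_R") and row.get("apo_L"):
            counts[BOTH_APO] += 1
    return {key: counts[key] for key in (*FLAG_COLUMNS, BOTH_APO)}


# Toy rows with invented IDs; real rows come from the pinder index.
toy_rows = [
    {"id": "sys_a", "apo_R": True, "apo_L": True, "predicted_R": False, "predicted_L": False},
    {"id": "sys_b", "apo_R": True, "apo_L": False, "predicted_R": True, "predicted_L": True},
    {"id": "sys_c", "apo_R": False, "apo_L": False, "predicted_R": False, "predicted_L": True},
]
print(monomer_availability(toy_rows))
# {'apo_R': 2, 'apo_L': 1, 'predicted_R': 1, 'predicted_L': 2, 'apo_R and apo_L': 1}
```

If you prefer to stay in pandas, summing the same boolean columns on the index should give the per-column counts directly.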
We can now access e.g. the sequence and the coordinates of the structures via the Structure objects:
receptor.sequence
'PLGSMFERFTDRARRVVVLAQEEARMLNHNYIGTEHILLGLIHEGEGVAAKSLESLGISLEGVRSQVEEIIGQGQQAPSGHIPFTPRAKKVLELSLREALQLGHNYIGTEHILLGLIREGEGVAAQVLVKLGAELTRVRQQVIQLLSGY'
receptor.coords[0:5]
array([[-12.982, -17.271, -11.271],
[-14.36 , -17.069, -11.749],
[-15.261, -16.373, -10.703],
[-15.461, -15.161, -10.801],
[-14.842, -18.494, -12.077]], dtype=float32)
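As a starting point for your own featurization, the sequence string can be mapped to integer tokens. This is only a sketch under stated assumptions: the 20-letter alphabet ordering, the extra "unknown" index, and the helper names `encode_sequence` and `one_hot` are illustrative choices, not something pinder prescribes.

```python
from typing import Dict, List, Sequence

# Assumed alphabet: the 20 standard amino acids in alphabetical one-letter order.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
UNKNOWN_INDEX = len(AMINO_ACIDS)  # index 20 is reserved for anything else
AA_TO_INDEX: Dict[str, int] = {aa: i for i, aa in enumerate(AMINO_ACIDS)}


def encode_sequence(sequence: str) -> List[int]:
    """Map each residue letter to an integer token."""
    return [AA_TO_INDEX.get(residue, UNKNOWN_INDEX) for residue in sequence.upper()]


def one_hot(indices: Sequence[int], num_classes: int = UNKNOWN_INDEX + 1) -> List[List[int]]:
    """Expand integer tokens into one-hot rows of length num_classes."""
    return [[1 if i == idx else 0 for i in range(num_classes)] for idx in indices]


# First five residues of the receptor sequence printed above.
prefix = "PLGSM"
tokens = encode_sequence(prefix)
print(tokens)  # [12, 9, 5, 15, 10]
features = one_hot(tokens)
print(len(features), len(features[0]))  # 5 21
```

In practice you would call `encode_sequence(receptor.sequence)` and convert the result to a tensor in your framework of choice.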
We can always access the underlying biotite AtomArray via the Structure.atom_array property:
receptor.atom_array[0:5]
array([
Atom(np.array([-12.982, -17.271, -11.271], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="N", element="N", b_factor=0.0),
Atom(np.array([-14.36 , -17.069, -11.749], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="CA", element="C", b_factor=0.0),
Atom(np.array([-15.261, -16.373, -10.703], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="C", element="C", b_factor=0.0),
Atom(np.array([-15.461, -15.161, -10.801], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="O", element="O", b_factor=0.0),
Atom(np.array([-14.842, -18.494, -12.077], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="CB", element="C", b_factor=0.0)
])
For a more comprehensive overview of all of the Structure class properties, refer to the pinder system tutorial.
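The system and monomer identifiers seen in the outputs above appear to follow the pattern `<pdb_id>__<chain>_<uniprot>`, with the receptor and ligand parts joined by `--`. Assuming that pattern holds (it is inferred from the printed entry, where pdb_id, chain_R, chain_L and uniprot_R match the pieces of the ID), a small parser can recover those fields without touching the index. `MonomerId`, `parse_monomer_id` and `parse_system_id` are hypothetical names, and the parser only handles the bare form without the `-R` / `-L` suffix.

```python
from typing import NamedTuple, Tuple


class MonomerId(NamedTuple):
    pdb_id: str
    chain: str
    uniprot: str


def parse_monomer_id(monomer_id: str) -> MonomerId:
    """Split '<pdb_id>__<chain>_<uniprot>' into its three parts."""
    pdb_id, sep, rest = monomer_id.partition("__")
    chain, sep2, uniprot = rest.partition("_")
    if not (sep and sep2 and pdb_id and chain and uniprot):
        raise ValueError(f"unexpected monomer ID format: {monomer_id!r}")
    return MonomerId(pdb_id, chain, uniprot)


def parse_system_id(system_id: str) -> Tuple[MonomerId, MonomerId]:
    """Split '<receptor>--<ligand>' and parse both halves."""
    receptor, sep, ligand = system_id.partition("--")
    if not sep:
        raise ValueError(f"unexpected system ID format: {system_id!r}")
    return parse_monomer_id(receptor), parse_monomer_id(ligand)


receptor_id, ligand_id = parse_system_id("8phr__X4_UNDEFINED--8phr__W4_UNDEFINED")
print(receptor_id.chain, ligand_id.chain)  # X4 W4
```

When the index is available, reading the pdb_id, chain_R, chain_L, uniprot_R and uniprot_L columns directly is the safer route, since it does not depend on the ID format.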
Using the PinderLoader to load, filter and transform systems#
While the PinderSystem object provides self-contained access to the structures associated with a dimer system, the PinderLoader provides a base abstraction for how to iterate over systems, apply optional filters and/or transforms, and return the systems as an iterator. This construct is covered in a different tutorial.
Using the PinderLoader is not necessary to load systems in your own framework. It is simply one of the provided mechanisms if you find it useful.
The PinderLoader brings together filters, transforms and writers to create a generic PinderSystem iterator. It takes either a split name or a list of system IDs as input and can be used to sample alternative monomers to form dimer complexes to serve as e.g. features.
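To make the idea of combining filters and transforms into one iterator concrete, here is a generic sketch of the pattern in plain Python. It is not PinderLoader's implementation: `iterate_systems`, the toy dictionaries and the `n_contacts` field are all hypothetical and only illustrate the flow of dropping items that fail any filter and then applying each transform in order.

```python
from typing import Callable, Dict, Iterable, Iterator, Sequence, TypeVar

T = TypeVar("T")


def iterate_systems(
    items: Iterable[T],
    filters: Sequence[Callable[[T], bool]] = (),
    transforms: Sequence[Callable[[T], T]] = (),
) -> Iterator[T]:
    """Yield items that pass every filter, after applying each transform in order."""
    for item in items:
        if not all(keep(item) for keep in filters):
            continue  # skip items rejected by any filter
        for transform in transforms:
            item = transform(item)
        yield item


def has_enough_contacts(system: Dict[str, object]) -> bool:
    # Toy filter: keep systems with at least 5 (made-up) contacts.
    return int(system["n_contacts"]) >= 5


def tag_as_train(system: Dict[str, object]) -> Dict[str, object]:
    # Toy transform: return a copy with an extra field.
    return {**system, "split": "train"}


toy_systems = [
    {"id": "sys_a", "n_contacts": 12},
    {"id": "sys_b", "n_contacts": 3},
    {"id": "sys_c", "n_contacts": 8},
]
kept = list(iterate_systems(toy_systems, [has_enough_contacts], [tag_as_train]))
print([s["id"] for s in kept])  # ['sys_a', 'sys_c']
```

Because the function is a generator, items are processed lazily, which matters when each item triggers file downloads or structure parsing.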
Loading a specific split#
Note: only the test dataset has subsets defined (pinder_s, pinder_xl, pinder_af2).
For train and val, you could just do:
train_loader = PinderLoader(split="train")
val_loader = PinderLoader(split="val")
import torch
from pinder.core import PinderLoader
from pinder.core.loader import filters
base_filters = [
    filters.FilterByMissingHolo(),
    filters.FilterSubByContacts(min_contacts=5, radius=10.0, calpha_only=True),
    filters.FilterDetachedHolo(radius=12, max_components=2),
]
sub_filters = [
    filters.FilterSubByAtomTypes(min_atom_types=4),
    filters.FilterByHoloOverlap(min_overlap=5),
    filters.FilterByHoloSeqIdentity(min_sequence_identity=0.8),
    filters.FilterSubRmsds(rmsd_cutoff=7.5),
    filters.FilterDetachedSub(radius=12, max_components=2),
]
loader = PinderLoader(
    split="test",
    subset="pinder_af2",
    monomer_priority="holo",
    base_filters=base_filters,
    sub_filters=sub_filters,
)
loader
PinderLoader(split=test, monomers=holo, systems=180)
len(loader)
180
data = loader[0]
print(f"Data is a {type(data)}")
system, feature_complex, target_complex = data
type(system), type(feature_complex), type(target_complex)
2024-11-15 12:09:58,528 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
Data is a <class 'tuple'>
(pinder.core.index.system.PinderSystem,
pinder.core.loader.structure.Structure,
pinder.core.loader.structure.Structure)
# You can also use it as an iterator
from tqdm import tqdm
loaded_ids = []
for (system, feature_complex, target_complex) in tqdm(loader):
    loaded_ids.append(system.entry.id)
0%| | 0/180 [00:00<?, ?it/s]
1%| | 1/180 [00:00<00:50, 3.58it/s]
2024-11-15 12:10:00,180 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
1%| | 2/180 [00:01<02:08, 1.39it/s]
92%|█████████▏| 166/180 [03:17<00:12, 1.15it/s]
2024-11-15 12:13:17,288 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
93%|█████████▎| 167/180 [03:18<00:10, 1.19it/s]
2024-11-15 12:13:18,067 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
93%|█████████▎| 168/180 [03:19<00:10, 1.15it/s]
2024-11-15 12:13:18,998 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
94%|█████████▍| 169/180 [03:20<00:09, 1.12it/s]
2024-11-15 12:13:19,948 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
94%|█████████▍| 170/180 [03:20<00:07, 1.26it/s]
2024-11-15 12:13:20,518 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
95%|█████████▌| 171/180 [03:21<00:07, 1.23it/s]
96%|█████████▌| 172/180 [03:21<00:05, 1.50it/s]
2024-11-15 12:13:21,690 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
96%|█████████▌| 173/180 [03:22<00:05, 1.33it/s]
2024-11-15 12:13:22,638 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
97%|█████████▋| 174/180 [03:24<00:05, 1.02it/s]
2024-11-15 12:13:24,158 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
97%|█████████▋| 175/180 [03:25<00:04, 1.07it/s]
2024-11-15 12:13:24,981 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
98%|█████████▊| 176/180 [03:26<00:03, 1.01it/s]
2024-11-15 12:13:26,092 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
98%|█████████▊| 177/180 [03:27<00:03, 1.05s/it]
2024-11-15 12:13:27,285 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
99%|█████████▉| 178/180 [03:28<00:02, 1.10s/it]
2024-11-15 12:13:28,511 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
99%|█████████▉| 179/180 [03:29<00:01, 1.00s/it]
2024-11-15 12:13:29,282 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
100%|██████████| 180/180 [03:30<00:00, 1.06it/s]
100%|██████████| 180/180 [03:30<00:00, 1.17s/it]
Loading a specific list of systems#
systems = [
    "1df0__A1_Q07009--1df0__B1_Q64537",
    "117e__A1_P00817--117e__B1_P00817",
]
loader = PinderLoader(
    ids=systems,
    monomer_priority="holo",
    base_filters=base_filters,
    sub_filters=sub_filters,
)
# Collect the IDs of the systems that pass all filters
passing_ids = []
for item in loader:
    passing_ids.append(item[0].entry.id)
systems_removed_by_filters = set(systems) - set(passing_ids)
systems_removed_by_filters
set()
len(systems) == len(passing_ids)
True
Optional Pinder writer#
Without defining a writer for the PinderLoader, the loaded systems are available as a tuple of (PinderSystem, Structure, Structure) objects, containing the original PinderSystem and the sampled feature and target complexes, respectively.
If you want to explicitly write the (potentially transformed) structure objects to a custom location or in a custom format (e.g. PDB, pickle, etc.), you can implement a subclass of PinderWriterBase.
The default writer implements writing to PDB files (leveraging the Structure.to_pdb method on the structure objects).
from pinder.core.loader.writer import PinderDefaultWriter
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp_dir:
    temp_dir = Path(tmp_dir)
    loader = PinderLoader(
        ids=systems,
        monomer_priority="pred",
        writer=PinderDefaultWriter(temp_dir)
    )
    assert set(loader.index.id) == set(systems)
    for i, r in loader.index.iterrows():
        # Indexing the loader triggers the writer for that system
        loaded = loader[i]
        pinder_id = r.id
        system_dir = loader.writer.output_path / pinder_id
        assert system_dir.is_dir()
        print(list(system_dir.glob("af_*.pdb")))
[PosixPath('/tmp/tmphvp65592/117e__A1_P00817--117e__B1_P00817/af__P00817.pdb')]
[PosixPath('/tmp/tmphvp65592/1df0__A1_Q07009--1df0__B1_Q64537/af__Q07009.pdb'), PosixPath('/tmp/tmphvp65592/1df0__A1_Q07009--1df0__B1_Q64537/af__Q64537.pdb')]
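As a rough illustration of the subclassing idea described above, the sketch below writes each item to a pickle file instead of a PDB file. It deliberately does not import pinder: WriterBase is a hypothetical stand-in for PinderWriterBase, and the write(system_id, payload) signature is an assumption made for this example, so please check the actual base class in pinder.core.loader.writer before adapting it.

```python
import pickle
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any


class WriterBase:
    """Hypothetical stand-in for PinderWriterBase (the real interface may differ)."""

    def __init__(self, output_path: Path) -> None:
        self.output_path = Path(output_path)

    def write(self, system_id: str, payload: Any) -> Path:
        raise NotImplementedError


class PickleWriter(WriterBase):
    """Write one pickle file per system under output_path/<system_id>/."""

    def write(self, system_id: str, payload: Any) -> Path:
        system_dir = self.output_path / system_id
        system_dir.mkdir(parents=True, exist_ok=True)
        out_file = system_dir / "item.pkl"
        with out_file.open("wb") as handle:
            pickle.dump(payload, handle)
        return out_file


def roundtrip_example() -> dict:
    """Write a toy payload and read it back, returning the loaded object."""
    with TemporaryDirectory() as tmp_dir:
        writer = PickleWriter(Path(tmp_dir))
        path = writer.write("117e__A1_P00817--117e__B1_P00817", {"n_atoms": 4})
        with path.open("rb") as handle:
            return pickle.load(handle)
```

The one-directory-per-system layout mirrors what the default writer output above suggests; the file name item.pkl is arbitrary.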
Constructing torch datasets and dataloaders from pinder systems#
The remaining sections of this tutorial will be for those interested specifically in torch datasets and dataloaders.
Specifically, we will show how to:
Implement a PyTorch Dataset to interface with pinder data
Include apo and predicted monomers in the data pipeline, with an option to target specific monomer types or randomly sample from the available types
Leverage PinderSystem and its associated methods to crop apo/predicted monomers to match the ground-truth holo monomers
Write filters and transforms that operate on Structure objects
Integrate annotations in data filtering and featurization
Create example features to use for training (you will of course choose your own features)
Incorporate diversity sampling in the data loader
The pinder.core.loader.dataset module provides two example implementations of how to integrate the pinder dataset into a torch-based machine learning pipeline.
PinderDataset: A map-style torch.utils.data.Dataset that can be used with torch DataLoaders.
PPIDataset: A torch_geometric.data.Dataset that can be used with torch-geometric DataLoaders. This dataset is designed to be used with the torch_geometric package.
Together, the two datasets provide an example implementation of how to abstract away the complexity of loading and processing multiple structures associated with each PinderSystem by leveraging the following utilities from pinder:
pinder.core.PinderLoader
pinder.core.loader.filters
pinder.core.loader.transforms
The examples cover two different batch data item structures to illustrate two different use-cases:
PinderDataset: A batch of (target_complex, feature_complex) pairs, where target_complex and feature_complex are torch.Tensor objects representing the atomic coordinates and atom types of the holo and sampled (decoy, holo/apo/pred) complexes, respectively.
PPIDataset: A batch of PairedPDB objects, where the receptor and ligand are encoded separately in a heterogeneous graph, via torch_geometric.data.HeteroData, holding multiple node and/or edge types in disjunct storage objects.
The remaining sections will be split into:
Using the PinderDataset torch dataset
Using the PPIDataset torch-geometric dataset
How you could implement your own dataset & dataloader
PinderDataset (torch Dataset)#
The PinderDataset is an example implementation of a torch.utils.data.Dataset that represents its data items as a dict containing the following key, value pairs:
target_complex: The ground-truth holo dimer, represented with a set of default properties encoded as Tensors
feature_complex: The sampled dimer complex, representing “features”, also represented with a set of default properties encoded as Tensors
id: The pinder ID for the selected system
target_id: The IDs of the receptor and ligand holo monomers, concatenated into a single ID string
sample_id: The IDs of the sampled receptor and ligand monomers, concatenated into a single ID string. This can be useful for debugging purposes or generally tracking which specific monomers are selected when targeting alternative monomers (more on this shortly)
The target_complex and feature_complex values are each dictionaries with structural properties encoded by the pinder.core.loader.geodata.structure2tensor function by default:
atom_coordinates
atom_types
element_types
chain_ids
residue_coordinates
residue_types
residue_ids
Note: You can choose to use a different representation by overriding the default values of transform and target_transform. The default transform is the structure2tensor_transform defined in pinder.core.loader.dataset. It simply takes a Structure object and returns a dictionary with string keys and tensor values.
PinderDataset leverages the PinderLoader to apply optional filters and/or transforms, provides an interface for sampling alternative monomers, and exposes the transform and target_transform arguments used by the torch Dataset API.
For more details on the torch Dataset APIs, please refer to the tutorials.
from pinder.core.loader import filters, transforms
from pinder.core.loader.dataset import PinderDataset, structure2tensor_transform

base_filters = [
    filters.FilterByMissingHolo(),
    filters.FilterSubByContacts(min_contacts=5, radius=10.0, calpha_only=True),
    filters.FilterDetachedHolo(radius=12, max_components=2),
]
sub_filters = [
    filters.FilterSubByAtomTypes(min_atom_types=4),
    filters.FilterByHoloOverlap(min_overlap=5),
    filters.FilterByHoloSeqIdentity(min_sequence_identity=0.8),
    filters.FilterSubRmsds(rmsd_cutoff=7.5),
    filters.FilterDetachedSub(radius=12, max_components=2),
]
# We can include Structure-level transforms (and filters) which will operate on the target and/or feature complexes
structure_transforms = [
    transforms.SelectAtomTypes(atom_types=["CA", "N", "C", "O"])
]
train_dataset = PinderDataset(
    split="train",
    # We can leverage holo, apo, pred, random and random_mixed monomer sampling strategies
    monomer_priority="random_mixed",
    base_filters=base_filters,
    sub_filters=sub_filters,
    # Apply to the target (ground-truth) complex
    structure_transforms_target=structure_transforms,
    # Apply to the feature complex
    structure_transforms_feature=structure_transforms,
    # This is the default transform if not specified
    transform=structure2tensor_transform,
    target_transform=structure2tensor_transform,
)
assert len(train_dataset) == len(get_index().query('split == "train"'))
train_dataset
<pinder.core.loader.dataset.PinderDataset at 0x7f697dd23ac0>
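To make the role of transform and target_transform a little more concrete, here is a minimal, torch-free sketch of the map-style protocol (__len__ and __getitem__) that returns items with the same top-level keys as PinderDataset. ToyStructure, ToyPairDataset and center_coords are hypothetical names used only for illustration, and coordinates are plain Python tuples rather than tensors. A real implementation would subclass torch.utils.data.Dataset and delegate loading to PinderLoader.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

Coords = List[Tuple[float, ...]]


@dataclass
class ToyStructure:
    """Hypothetical stand-in for a pinder Structure: a name plus coordinates."""
    name: str
    coords: Coords


def center_coords(structure: ToyStructure) -> Dict[str, Coords]:
    """Example transform: shift coordinates so their centroid sits at the origin."""
    n = len(structure.coords)
    centroid = tuple(sum(c[i] for c in structure.coords) / n for i in range(3))
    centered = [tuple(c[i] - centroid[i] for i in range(3)) for c in structure.coords]
    return {"atom_coordinates": centered}


class ToyPairDataset:
    """Map-style dataset: implements __len__ and __getitem__."""

    def __init__(
        self,
        pairs: List[Tuple[str, ToyStructure, ToyStructure]],
        transform: Callable[[ToyStructure], Dict[str, Coords]] = center_coords,
        target_transform: Callable[[ToyStructure], Dict[str, Coords]] = center_coords,
    ) -> None:
        # Each entry is (system_id, feature_structure, target_structure)
        self.pairs = pairs
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> dict:
        system_id, feature, target = self.pairs[idx]
        return {
            # target_transform is applied to the ground-truth complex
            "target_complex": self.target_transform(target),
            # transform is applied to the sampled (feature) complex
            "feature_complex": self.transform(feature),
            "id": system_id,
            "sample_id": feature.name,
            "target_id": target.name,
        }


holo = ToyStructure("holo", [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0)])
pred = ToyStructure("pred", [(1.0, 1.0, 1.0), (3.0, 1.0, 1.0)])
dataset = ToyPairDataset([("toy_system", pred, holo)])
item = dataset[0]
```

Swapping in a different representation then only requires passing another callable for transform or target_transform, which is the same extension point the note above describes.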
Sampling alternative monomers#
The monomer_priority argument can be used to target different mixes of bound and unbound monomers to use for creating the decoy/feature complex.
The allowed values for monomer_priority are “apo”, “holo”, “pred”, “random” or “random_mixed”.
When monomer_priority is set to one of the available monomer types (holo, apo, pred), the same monomer type will be selected for both receptor and ligand.
When the monomer priority is “random”, a random monomer type will be selected from the set of monomer types available for both the receptor and ligand. This option ensures the same type of monomer is used for the receptor and ligand.
When the monomer priority is “random_mixed”, a random monomer type will be selected for each of receptor and ligand, separately.
Enabling the fallback_to_holo option (the default) will enable silent fallback to holo when the monomer_priority is set to one of apo or pred, but the corresponding monomer is not available for the dimer.
This is useful when only one of receptor or ligand has an unbound monomer, but you wish to include apo or predicted structures in your workflow.
If fallback_to_holo is disabled, an error will be raised when the monomer_priority is set to one of apo or pred, but the corresponding monomer is not available for the dimer.
By default, when apo monomers are selected, the “canonical” apo monomer is used. Although a single canonical apo monomer should be used for eval, pinder provides multiple apo monomers paired to a single holo monomer (when available). In order to include these non-canonical/alternative monomers, you can specify use_canonical_apo=False when constructing the PinderLoader or PinderDataset objects.
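The selection rules above can be summarised in a few lines of plain Python. This is only a sketch of the described behaviour, not pinder's implementation: select_monomer_types and its arguments are hypothetical, holo is assumed to always be available, and applying the holo fallback independently to each side is our reading of the description.

```python
import random
from typing import Dict, Optional, Set, Tuple

ALLOWED = {"apo", "holo", "pred", "random", "random_mixed"}


def select_monomer_types(
    available: Dict[str, Set[str]],
    monomer_priority: str = "holo",
    fallback_to_holo: bool = True,
    rng: Optional[random.Random] = None,
) -> Tuple[str, str]:
    """Return the (receptor, ligand) monomer types for the feature complex.

    `available` maps "R" and "L" to the monomer types present for that side.
    """
    if monomer_priority not in ALLOWED:
        raise ValueError(f"Unknown monomer_priority: {monomer_priority}")
    rng = rng or random.Random()
    # Assumption: the holo monomer exists for both sides of every dimer.
    rec = available["R"] | {"holo"}
    lig = available["L"] | {"holo"}

    if monomer_priority == "random":
        # Same type for both sides, drawn from the types both sides share.
        choice = rng.choice(sorted(rec & lig))
        return choice, choice
    if monomer_priority == "random_mixed":
        # Independent draw for each side.
        return rng.choice(sorted(rec)), rng.choice(sorted(lig))

    # Fixed priority (holo, apo or pred) requested for both sides.
    selected = []
    for side, types in (("R", rec), ("L", lig)):
        if monomer_priority in types:
            selected.append(monomer_priority)
        elif fallback_to_holo:
            # Silent fallback when the requested type is missing for this side.
            selected.append("holo")
        else:
            raise ValueError(f"No {monomer_priority} monomer available for {side}")
    return selected[0], selected[1]
```

For a dimer where only the receptor has an apo structure, requesting "apo" with the fallback enabled yields an apo receptor paired with a holo ligand, while disabling the fallback raises an error.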
data_item = train_dataset[0]
data_item
{'target_complex': {'atom_types': tensor([[0.],
[1.],
[2.],
...,
[1.],
[2.],
[3.]]),
'element_types': tensor([[3.],
[0.],
[0.],
...,
[0.],
[0.],
[2.]]),
'residue_types': tensor([[16.],
[16.],
[16.],
...,
[ 0.],
[ 0.],
[ 0.]]),
'atom_coordinates': tensor([[131.7500, 429.3090, 163.5360],
[132.6810, 428.2520, 163.1550],
[133.5150, 428.6750, 161.9500],
...,
[177.7620, 463.8650, 166.9020],
[177.4130, 465.0800, 167.7550],
[176.8000, 464.9490, 168.8150]]),
'residue_coordinates': tensor([[132.6810, 428.2520, 163.1550],
[133.5560, 429.2490, 159.5910],
[133.8750, 432.8980, 160.6290],
...,
[181.8140, 459.1820, 170.0770],
[178.5990, 461.0580, 169.3130],
[177.7620, 463.8650, 166.9020]]),
'residue_ids': tensor([ 4., 4., 4., ..., 182., 182., 182.]),
'chain_ids': tensor([0., 0., 0., ..., 1., 1., 1.])},
'feature_complex': {'atom_types': tensor([[0.],
[1.],
[2.],
...,
[1.],
[2.],
[3.]]),
'element_types': tensor([[3.],
[0.],
[0.],
...,
[0.],
[0.],
[2.]]),
'residue_types': tensor([[16.],
[16.],
[16.],
...,
[ 0.],
[ 0.],
[ 0.]]),
'atom_coordinates': tensor([[131.7500, 429.3090, 163.5360],
[132.6810, 428.2520, 163.1550],
[133.5150, 428.6750, 161.9500],
...,
[177.7620, 463.8650, 166.9020],
[177.4130, 465.0800, 167.7550],
[176.8000, 464.9490, 168.8150]]),
'residue_coordinates': tensor([[132.6810, 428.2520, 163.1550],
[133.5560, 429.2490, 159.5910],
[133.8750, 432.8980, 160.6290],
...,
[171.0060, 456.0190, 188.1270],
[170.8210, 453.3390, 185.4210],
[170.0180, 454.8270, 182.0350],
[171.8690, 458.1080, 182.6590],
[174.2320, 459.6520, 180.1210],
[177.8640, 459.9130, 181.2230],
[180.8780, 461.8410, 179.9490],
[184.5650, 461.8070, 180.8430],
[186.0650, 464.7990, 182.6440],
[189.5780, 466.2740, 182.6170],
[190.7120, 463.6340, 185.1390],
[189.2980, 460.6630, 183.2290],
[186.3560, 460.3000, 185.6240],
[182.7840, 459.6590, 184.5210],
[180.2130, 462.3320, 185.3590],
[176.5320, 462.8410, 184.6140],
[175.8580, 465.0300, 181.5790],
[178.1930, 465.7110, 174.9000],
[180.9150, 464.1190, 172.7670],
[181.2510, 460.5420, 174.1130],
[178.4430, 457.9660, 174.2140],
[178.5420, 456.4790, 177.7180],
[175.4730, 455.1590, 179.5500],
[175.4670, 454.0690, 183.1870],
[174.1210, 450.5430, 183.6830],
[173.9850, 450.6340, 187.5030],
[174.3640, 452.9240, 190.5240],
[177.5140, 454.2200, 192.1820],
[179.3060, 451.6570, 194.3540],
[181.4890, 453.2110, 197.0570],
[184.7390, 451.3110, 197.6330],
[186.4870, 454.0230, 199.6690],
[185.5270, 457.3810, 201.1790],
[186.7910, 459.0050, 197.9610],
[186.6010, 456.0430, 195.5320],
[183.3680, 455.3620, 193.6270],
[182.9110, 452.9430, 190.7280],
[180.0190, 452.5230, 188.3060],
[179.0910, 450.0250, 185.5960],
[178.9850, 451.7840, 182.2280],
[179.0510, 451.0080, 178.5100],
[181.0530, 453.1550, 176.0780],
[180.1320, 452.7170, 172.4060],
[182.0570, 455.4530, 170.6120],
[181.8140, 459.1820, 170.0770],
[178.5990, 461.0580, 169.3130],
[177.7620, 463.8650, 166.9020]]),
'residue_ids': tensor([ 4., 4., 4., ..., 182., 182., 182.]),
'chain_ids': tensor([0., 0., 0., ..., 1., 1., 1.])},
'id': '8phr__X4_UNDEFINED--8phr__W4_UNDEFINED',
'sample_id': '8phr__X4_UNDEFINED-R--8phr__W4_UNDEFINED-L',
'target_id': '8phr__X4_UNDEFINED-R--8phr__W4_UNDEFINED-L'}
# Because crop_equal_monomer_shapes is enabled by default, the feature and target complex coordinates should have identical shapes
assert (
data_item["feature_complex"]["atom_coordinates"].shape
== data_item["target_complex"]["atom_coordinates"].shape
)
data_item["feature_complex"]["atom_coordinates"].shape
torch.Size([1316, 3])
help(PinderDataset)
Help on class PinderDataset in module pinder.core.loader.dataset:
class PinderDataset(torch.utils.data.dataset.Dataset)
| PinderDataset(split: 'str | None' = None, index: 'pd.DataFrame | None' = None, metadata: 'pd.DataFrame | None' = None, monomer_priority: 'str' = 'holo', base_filters: 'list[PinderFilterBase]' = [], sub_filters: 'list[PinderFilterSubBase]' = [], structure_filters: 'list[StructureFilter]' = [], structure_transforms_target: 'list[StructureTransform]' = [], structure_transforms_feature: 'list[StructureTransform]' = [], transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = <function structure2tensor_transform at 0x7f69602b55a0>, target_transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = <function structure2tensor_transform at 0x7f69602b55a0>, ids: 'list[str] | None' = None, fallback_to_holo: 'bool' = True, use_canonical_apo: 'bool' = True, crop_equal_monomer_shapes: 'bool' = True, index_query: 'str | None' = None, metadata_query: 'str | None' = None, pre_specified_monomers: 'dict[str, str] | pd.DataFrame | None' = None, **kwargs: 'Any') -> 'None'
|
| Method resolution order:
| PinderDataset
| torch.utils.data.dataset.Dataset
| typing.Generic
| builtins.object
|
| Methods defined here:
|
| __getitem__(self, idx: 'int') -> 'dict[str, dict[str, torch.Tensor] | torch.Tensor]'
|
| __init__(self, split: 'str | None' = None, index: 'pd.DataFrame | None' = None, metadata: 'pd.DataFrame | None' = None, monomer_priority: 'str' = 'holo', base_filters: 'list[PinderFilterBase]' = [], sub_filters: 'list[PinderFilterSubBase]' = [], structure_filters: 'list[StructureFilter]' = [], structure_transforms_target: 'list[StructureTransform]' = [], structure_transforms_feature: 'list[StructureTransform]' = [], transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = <function structure2tensor_transform at 0x7f69602b55a0>, target_transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = <function structure2tensor_transform at 0x7f69602b55a0>, ids: 'list[str] | None' = None, fallback_to_holo: 'bool' = True, use_canonical_apo: 'bool' = True, crop_equal_monomer_shapes: 'bool' = True, index_query: 'str | None' = None, metadata_query: 'str | None' = None, pre_specified_monomers: 'dict[str, str] | pd.DataFrame | None' = None, **kwargs: 'Any') -> 'None'
| Initialize self. See help(type(self)) for accurate signature.
|
| __len__(self) -> 'int'
|
| ----------------------------------------------------------------------
| Data and other attributes defined here:
|
| __annotations__ = {}
|
| __parameters__ = ()
|
| ----------------------------------------------------------------------
| Methods inherited from torch.utils.data.dataset.Dataset:
|
| __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]'
|
| ----------------------------------------------------------------------
| Data descriptors inherited from torch.utils.data.dataset.Dataset:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
|
| ----------------------------------------------------------------------
| Data and other attributes inherited from torch.utils.data.dataset.Dataset:
|
| __orig_bases__ = (typing.Generic[+T_co],)
|
| ----------------------------------------------------------------------
| Class methods inherited from typing.Generic:
|
| __class_getitem__(params) from builtins.type
|
| __init_subclass__(*args, **kwargs) from builtins.type
| This method is called when a class is subclassed.
|
| The default implementation does nothing. It may be
| overridden to extend subclasses.
Torch DataLoader for PinderDataset#
The PinderDataset can be served by a torch.utils.data.DataLoader.
There is a convenience function, pinder.core.loader.dataset.get_torch_loader, for taking a PinderDataset and returning a DataLoader for the dataset object.
We can leverage the default collate_fn (pinder.core.loader.dataset.collate_batch) to merge multiple systems (Dataset items) to create mini-batches of tensors:
from pinder.core.loader.dataset import collate_batch, get_torch_loader
from torch.utils.data import DataLoader
batch_size = 2
train_dataloader = get_torch_loader(
train_dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collate_batch,
num_workers=0,
)
assert isinstance(train_dataloader, DataLoader)
assert hasattr(train_dataloader, "dataset")
# Get a batch from the dataloader
batch = next(iter(train_dataloader))
# expected batch dict keys
assert set(batch.keys()) == {
"target_complex",
"feature_complex",
"id",
"sample_id",
"target_id",
}
assert isinstance(batch["target_complex"], dict)
assert isinstance(batch["target_complex"]["atom_coordinates"], torch.Tensor)
feature_coords = batch["feature_complex"]["atom_coordinates"]
# Ensure batch size propagates to tensor dims
assert feature_coords.shape[0] == batch_size
# Ensure coordinates have dim 3
assert feature_coords.shape[2] == 3
2024-11-15 12:13:41,136 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
2024-11-15 12:13:41,860 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
Torch geometric Dataset#
# Make sure to install torch_cluster
# !pip install torch_cluster
from pinder.core.loader.dataset import PPIDataset
from pinder.core.loader.geodata import NodeRepresentation
nodes = {NodeRepresentation("atom"), NodeRepresentation("residue")}
train_dataset = PPIDataset(
node_types=nodes,
split="train",
monomer1="holo_receptor",
monomer2="holo_ligand",
limit_by=5,
force_reload=True,
parallel=False,
)
assert len(train_dataset) == 5
train_dataset
Processing...
2024-11-15 12:13:49,610 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
2024-11-15 12:13:50,822 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=5
2024-11-15 12:13:51,243 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
2024-11-15 12:13:52,054 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
2024-11-15 12:13:52,436 | pinder.core.loader.dataset:550 | INFO : Finished processing, only 5 systems
Done!
PPIDataset(5)
from torch_geometric.data import HeteroData
from pinder.core import get_index
# Here we check that we correctly capped the number of systems to load to 5 (via limit_by)
pindex = get_index()
raw_ids = set(train_dataset.raw_file_names)
assert len(raw_ids.intersection(set(pindex.id))) == 5
processed_ids = {f.stem for f in train_dataset.processed_file_names}
# Here we ensure that all 5 ids got processed and saved as .pt file on disk
assert len(processed_ids.intersection(set(pindex.id))) == 5
# Let's get an item from the dataset by index
data_item = train_dataset[0]
assert isinstance(data_item, HeteroData)
data_item
PairedPDB(
ligand_residue={
residueid=[121, 1],
pos=[121, 3],
edge_index=[2, 1210],
chain=[1],
},
receptor_residue={
residueid=[144, 1],
pos=[144, 3],
edge_index=[2, 1440],
chain=[1],
},
ligand_atom={
x=[958, 1],
pos=[958, 3],
edge_index=[2, 9580],
},
receptor_atom={
x=[1119, 1],
pos=[1119, 3],
edge_index=[2, 11190],
},
pdb={
id=[1],
num_nodes=1,
}
)
data_item.num_node_features
{'ligand_residue': 0,
'receptor_residue': 0,
'ligand_atom': 1,
'receptor_atom': 1,
'pdb': 0}
data_item["ligand_atom"]
{'x': tensor([[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[21.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[10.],
[23.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[17.],
[ 5.],
[ 9.],
[20.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[17.],
[ 5.],
[ 9.],
[20.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[21.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[32.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[22.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[32.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[17.],
[ 5.],
[ 9.],
[20.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[24.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[35.],
[23.],
[ 7.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[11.],
[30.],
[ 5.],
[14.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[24.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[24.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[27.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[24.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[11.],
[30.],
[ 5.],
[14.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[27.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[35.],
[23.],
[ 7.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[10.],
[23.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 6.],
[31.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[21.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[17.],
[ 5.],
[ 9.],
[20.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[35.],
[23.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[24.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[22.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[33.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[11.],
[30.],
[ 5.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[32.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[21.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[17.],
[ 5.],
[ 9.],
[20.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 6.],
[31.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[26.],
[25.],
[11.],
[27.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[35.],
[23.],
[ 7.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[17.],
[ 5.],
[ 9.],
[20.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[17.],
[ 5.],
[ 9.],
[20.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[35.],
[23.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[34.],
[30.],
[ 4.],
[13.],
[16.],
[18.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 6.],
[31.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[27.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[32.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[35.],
[23.],
[ 7.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[35.],
[23.],
[ 7.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[35.],
[23.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[22.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[17.],
[ 5.],
[ 9.],
[20.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[32.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[11.],
[30.],
[ 5.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[27.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[32.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[32.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[27.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[33.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[32.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[35.],
[23.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[35.],
[23.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[26.],
[25.],
[11.],
[27.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[11.],
[30.],
[ 5.],
[14.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[21.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[32.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[24.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[17.],
[ 5.],
[ 9.],
[20.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[10.],
[23.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[21.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[27.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[24.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[24.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[21.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[31.],
[28.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[35.],
[23.],
[ 7.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[21.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[29.],
[12.],
[24.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[10.],
[23.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[19.],
[32.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[ 0.],
[ 1.],
[ 2.],
[ 3.],
[ 8.],
[15.],
[ 7.],
[25.],
[11.],
[30.],
[ 5.],
[ 0.],
[ 1.],
[ 2.],
[ 3.]]), 'pos': tensor([[207.6910, 164.0670, 172.7860],
[208.6350, 163.8290, 173.9130],
[209.3790, 165.0800, 174.3180],
...,
[240.2730, 173.0310, 178.0680],
[240.1680, 174.1280, 177.0260],
[240.3170, 173.8780, 175.8300]]), 'edge_index': tensor([[ 1, 2, 4, ..., 946, 894, 893],
[ 0, 0, 0, ..., 957, 957, 957]])}
# We can also get an item by its system ID
data_item = train_dataset.get_filename("8phr__X4_UNDEFINED--8phr__W4_UNDEFINED")
data_item
PairedPDB(
ligand_residue={
residueid=[158, 1],
pos=[158, 3],
edge_index=[2, 1580],
chain=[1],
},
receptor_residue={
residueid=[171, 1],
pos=[171, 3],
edge_index=[2, 1710],
chain=[1],
},
ligand_atom={
x=[1198, 1],
pos=[1198, 3],
edge_index=[2, 11980],
},
receptor_atom={
x=[1358, 1],
pos=[1358, 3],
edge_index=[2, 13580],
},
pdb={
id=[1],
num_nodes=1,
}
)
train_dataset.print_summary()
PPIDataset (#graphs=5):
+------------+----------+----------+
| | #nodes | #edges |
|------------+----------+----------|
| mean | 2612.4 | 0 |
| std | 1631.6 | 0 |
| min | 999 | 0 |
| quantile25 | 1600 | 0 |
| median | 2343 | 0 |
| quantile75 | 2886 | 0 |
| max | 5234 | 0 |
+------------+----------+----------+
Number of nodes per node type:
+------------+------------------+--------------------+---------------+-----------------+-------+
| | ligand_residue | receptor_residue | ligand_atom | receptor_atom | pdb |
|------------+------------------+--------------------+---------------+-----------------+-------|
| mean | 140.2 | 152.8 | 1103.4 | 1215 | 1 |
| std | 90.2 | 87.5 | 738.9 | 717.1 | 0 |
| min | 58 | 62 | 426 | 452 | 1 |
| quantile25 | 78 | 97 | 622 | 802 | 1 |
| median | 121 | 144 | 958 | 1119 | 1 |
| quantile75 | 158 | 171 | 1198 | 1358 | 1 |
| max | 286 | 290 | 2313 | 2344 | 1 |
+------------+------------------+--------------------+---------------+-----------------+-------+
PairedPDB torch-geometric HeteroData object#
The PPIDataset represents its data items as PairedPDB objects, where the receptor and ligand are encoded separately in a heterogeneous graph, via torch_geometric.data.HeteroData, holding multiple node and/or edge types in disjunct storage objects.
It leverages the PinderLoader to apply optional filters and/or transforms and implements a caching system by saving processed systems to disk in .pt files.
The PairedPDB implements a conversion method .from_pinder_system that takes in a PinderSystem and converts it to a PairedPDB object.
For more details on the torch-geometric Dataset APIs, please refer to the tutorials.
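In the PairedPDB objects printed above, every edge_index has exactly ten columns per node (for example, 121 ligand residues and 1210 edges). That is consistent with a k-nearest-neighbour graph built with k = 10, although this reading is an inference from the printed shapes rather than something stated in the tutorial. The sketch below illustrates the idea in plain Python; knn_edge_index is a hypothetical helper, not a pinder or torch_cluster function.

```python
import math


def knn_edge_index(points, k):
    """Return [sources, targets]: each node receives edges from its k nearest neighbours."""
    sources, targets = [], []
    for i, p in enumerate(points):
        # Distance from node i to every other node (self excluded).
        dists = [(math.dist(p, q), j) for j, q in enumerate(points) if j != i]
        for _, j in sorted(dists)[:k]:
            sources.append(j)
            targets.append(i)
    return [sources, targets]


points = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0), (7.0, 0.0, 0.0)]
edge_index = knn_edge_index(points, k=2)
print(len(edge_index[0]))  # 4 nodes * 2 neighbours = 8 edges
```

With n nodes and k neighbours per node the edge count is always n * k, which matches the pattern in the shapes shown earlier.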
from pinder.core.loader.geodata import PairedPDB
from torch_geometric.data import HeteroData
pinder_id = "3s9d__B1_P48551--3s9d__A1_P01563"
system = PinderSystem(pinder_id)
holo_data = PairedPDB.from_pinder_system(
system=system,
monomer1="holo_receptor", monomer2="holo_ligand",
node_types=nodes,
)
assert isinstance(holo_data, HeteroData)
expected_node_types = [
'ligand_residue', 'receptor_residue', 'ligand_atom', 'receptor_atom'
]
assert holo_data.num_nodes == 2780
assert holo_data.num_edges == 0
assert isinstance(holo_data.num_node_features, dict)
expected_num_feats = {
'ligand_residue': 0,
'receptor_residue': 0,
'ligand_atom': 1,
'receptor_atom': 1
}
for k, v in expected_num_feats.items():
    assert holo_data.num_node_features[k] == v
assert holo_data.node_types == expected_node_types
holo_data
2024-11-15 12:13:53,464 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=15
PairedPDB(
ligand_residue={
residueid=[127, 1],
pos=[127, 3],
edge_index=[2, 1270],
chain=[1],
},
receptor_residue={
residueid=[180, 1],
pos=[180, 3],
edge_index=[2, 1800],
chain=[1],
},
ligand_atom={
x=[1032, 1],
pos=[1032, 3],
edge_index=[2, 10320],
},
receptor_atom={
x=[1441, 1],
pos=[1441, 3],
edge_index=[2, 14410],
}
)
# You can target specific monomers (apo/holo/pred) for the receptor and ligand
apo_data = PairedPDB.from_pinder_system(
system=system,
monomer1="apo_receptor", monomer2="apo_ligand",
node_types=nodes,
)
assert isinstance(apo_data, HeteroData)
assert apo_data.num_nodes == 3437
assert apo_data.num_edges == 0
assert isinstance(apo_data.num_node_features, dict)
expected_num_feats = {
'ligand_residue': 0,
'receptor_residue': 0,
'ligand_atom': 1,
'receptor_atom': 1,
}
for k, v in expected_num_feats.items():
    assert apo_data.num_node_features[k] == v
assert apo_data.node_types == expected_node_types
apo_data
PairedPDB(
ligand_residue={
residueid=[165, 1],
pos=[165, 3],
edge_index=[2, 1650],
chain=[1],
},
receptor_residue={
residueid=[212, 1],
pos=[212, 3],
edge_index=[2, 2120],
chain=[1],
},
ligand_atom={
x=[1350, 1],
pos=[1350, 3],
edge_index=[2, 13500],
},
receptor_atom={
x=[1710, 1],
pos=[1710, 3],
edge_index=[2, 17100],
}
)
Torch geometric DataLoader#
The PPIDataset can be served by a torch_geometric.loader.DataLoader.
There is a convenience function, pinder.core.loader.dataset.get_geo_loader, for taking a PPIDataset and returning a DataLoader for the dataset object.
from pinder.core.loader.dataset import get_geo_loader
from torch_geometric.loader import DataLoader
loader = get_geo_loader(train_dataset)
assert isinstance(loader, DataLoader)
assert hasattr(loader, "dataset")
ds = loader.dataset
assert len(ds) == 5
loader
<torch_geometric.loader.dataloader.DataLoader at 0x7f697dbfdb40>
ds
PPIDataset(5)
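Graph mini-batches work differently from padded tensor batches: torch-geometric loaders merge the graphs of a batch into one large disconnected graph by concatenating the node arrays and shifting each graph's edge indices by the number of nodes that precede it. The plain-Python sketch below shows that bookkeeping for a single node type. The function batch_graphs and its dictionary layout are hypothetical and exist only for illustration.

```python
def batch_graphs(graphs):
    """Merge graphs (dicts with 'num_nodes' and 'edge_index') into one disjoint graph."""
    sources, targets, batch = [], [], []
    offset = 0
    for graph_id, graph in enumerate(graphs):
        src, dst = graph["edge_index"]
        # Shift node indices so they stay unique within the merged graph.
        sources.extend(s + offset for s in src)
        targets.extend(t + offset for t in dst)
        # Remember which graph every node came from.
        batch.extend([graph_id] * graph["num_nodes"])
        offset += graph["num_nodes"]
    return {"num_nodes": offset, "edge_index": [sources, targets], "batch": batch}


g1 = {"num_nodes": 3, "edge_index": [[0, 1], [1, 2]]}
g2 = {"num_nodes": 2, "edge_index": [[0], [1]]}
merged = batch_graphs([g1, g2])
print(merged["edge_index"])  # [[0, 1, 3], [1, 2, 4]]
print(merged["batch"])       # [0, 0, 0, 1, 1]
```

Because no padding is involved, graphs of very different sizes can share a batch without wasted memory.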
Implementing your own PyTorch Dataset & DataLoader for pinder#
While the previous sections covered two example implementations of leveraging the torch and torch-geometric APIs, below we will illustrate how you can write your own PyTorch data pipeline to feed your model.
We will focus on writing a minimalistic example that simply fetches PDB files and returns the coordinates as tensors. We will also include an example of a sampler function for the dataloader to implement diversity sampling in your workflow.
Defining the Dataset#
Below we will write a barebones torch.utils.data.Dataset object that implements at a minimum:
__init__ method
__len__ method returning the number of items in the dataset
__getitem__ method that returns an item in the dataset
We will also add an option to include apo and predicted monomer types by adding a monomer_priority argument to our dataset. The argument will be set to “random” by default, indicating that we want to randomly sample a monomer type from the set of monomers available for a given system. We could also target a specific monomer type by setting this argument to one of “holo”, “apo”, or “predicted”. Not every system in the training set has apo or predicted structures available, so we will also add a fallback_to_holo argument to indicate whether we want to use holo monomers when the selected monomer type is not available.
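The selection rule described above can be sketched as follows. This is an illustration only: choose_monomer_type and its arguments are hypothetical, and the tutorial code itself delegates this step to pinder's select_monomer.

```python
import random

MONOMER_TYPES = ("holo", "apo", "predicted")


def choose_monomer_type(available, priority="random", fallback_to_holo=True, rng=random):
    """Pick a monomer type for one system, given the types it actually provides."""
    if priority == "random":
        # Sample uniformly from whatever this system provides.
        return rng.choice(sorted(available))
    if priority not in MONOMER_TYPES:
        raise ValueError(f"Unknown monomer_priority: {priority!r}")
    if priority in available:
        return priority
    if fallback_to_holo and "holo" in available:
        return "holo"
    # Signal to the caller that this system should be skipped.
    return None


print(choose_monomer_type({"holo", "apo"}, priority="apo"))          # apo
print(choose_monomer_type({"holo"}, priority="predicted"))           # holo
print(choose_monomer_type({"holo"}, "apo", fallback_to_holo=False))  # None
```

Returning None when nothing suitable exists lets the dataset decide whether to skip the system or resample another index.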
We also define three interfaces for applying filters to our dataset:
metadata_filter: a query string to apply to the pinder metadata pandas DataFrame
system_filters: list[PinderFilterBase]: a list of filters that inherit from a base class, PinderFilterBase, which serves as the abstraction layer for defining PinderSystem-based filters
structure_filters: list[StructureFilter]: a list of filters that inherit from a base class, StructureFilter, which serves as the abstraction layer for defining Structure-based filters
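As used in this tutorial, a filter is simply a callable that receives an object and returns a truthy value when the object should be kept. The sketch below mimics that contract with stand-in classes so that it runs without the dataset. FakeStructure, MinAtomCount, MaxAtomCount and passes_all are hypothetical names and are not part of pinder.

```python
from dataclasses import dataclass


@dataclass
class FakeStructure:
    """Stand-in for a structure; only the atom count matters here."""
    n_atoms: int


class MinAtomCount:
    """Callable filter: keep structures with at least min_atoms atoms."""

    def __init__(self, min_atoms):
        self.min_atoms = min_atoms

    def __call__(self, structure):
        return structure.n_atoms >= self.min_atoms


class MaxAtomCount:
    """Callable filter: keep structures with at most max_atoms atoms."""

    def __init__(self, max_atoms):
        self.max_atoms = max_atoms

    def __call__(self, structure):
        return structure.n_atoms <= self.max_atoms


def passes_all(item, item_filters):
    """True only if every filter accepts the item (an empty list accepts everything)."""
    return all(f(item) for f in item_filters)


size_filters = [MinAtomCount(100), MaxAtomCount(5000)]
print(passes_all(FakeStructure(1316), size_filters))  # True
print(passes_all(FakeStructure(42), size_filters))    # False
print(passes_all(FakeStructure(42), []))              # True
```

Keeping each filter as a small callable makes it easy to compose several of them in a list, which is the shape the dataset arguments expect.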
import random
import numpy as np
from numpy.typing import NDArray
from torch.utils.data import Dataset
from pinder.core import get_index, get_metadata, PinderSystem
from pinder.core.loader import filters
from pinder.core.loader.loader import _create_target_feature_complex, select_monomer
from pinder.core.loader.structure import Structure
index = get_index()
metadata = get_metadata()
class CustomPinderDataset(Dataset):
    def __init__(
        self,
        split: str,
        monomer_priority: str = "random",
        fallback_to_holo: bool = True,
        crop_equal_monomer_shapes: bool = True,
        use_canonical_apo: bool = True,
        transform=None,
        target_transform=None,
        structure_filters: list[filters.StructureFilter] | None = None,
        system_filters: list[filters.PinderFilterBase] | None = None,
        metadata_filter: str | None = None,
        max_load_attempts: int = 10,
    ) -> None:
        # Which split/mode we are using for the dataset instance
        self.split = split
        self.monomer_priority = monomer_priority
        self.fallback_to_holo = fallback_to_holo
        self.crop_equal_monomer_shapes = crop_equal_monomer_shapes
        # Optional transform and target transform to apply (will be covered shortly)
        self.transform = transform
        self.target_transform = target_transform
        # Optional system-level filters to apply (None means no filters; avoids a shared mutable default)
        self.system_filters = system_filters or []
        # Optional structure filters to apply
        self.structure_filters = structure_filters or []
        # Maximum number of times to try sampling another index from the dataset until an exception is raised
        self.max_load_attempts = max_load_attempts
        # Whether we should use canonical apo structures (apo_R/L_pdb columns in pinder index) if apo monomers are selected
        self.use_canonical_apo = use_canonical_apo
        # Define the subset of the pinder index and metadata corresponding to the split of our dataset instance
        self.index = index.query(f'split == "{split}"').reset_index(drop=True)
        self.metadata = metadata[metadata["id"].isin(set(self.index.id))].reset_index(drop=True)
        if metadata_filter:
            try:
                self.metadata = self.metadata.query(metadata_filter).reset_index(drop=True)
            except Exception as e:
                print(f"Failed to apply metadata_filter={metadata_filter}: {e}")
            # Keep only the index rows whose metadata survived the query
            self.index = self.index[self.index["id"].isin(set(self.metadata.id))].reset_index(drop=True)
    def __len__(self):
        return len(self.index)
    def __getitem__(self, idx: int) -> tuple[NDArray[np.double], NDArray[np.double]]:
        valid_idx = False
        attempts = 0
        while not valid_idx and attempts < self.max_load_attempts:
            attempts += 1
            row = self.index.iloc[idx]
            system = PinderSystem(row.id)
            system = self.apply_system_filters(system)
            if not isinstance(system, PinderSystem):
                # The system was rejected by a filter: resample instead of retrying the same index
                idx = random.randrange(len(self))
                continue
            selected_monomers = select_monomer(
                row,
                self.monomer_priority,
                self.fallback_to_holo,
                self.use_canonical_apo,
            )
            # With the system and selected_monomers objects, we can now create a pair of dimer complexes
            # Below we leverage the existing utility from the PinderLoader (_create_target_feature_complex)
            target_complex, feature_complex = _create_target_feature_complex(
                system, selected_monomers, self.crop_equal_monomer_shapes, self.fallback_to_holo
            )
            valid_idx = self.apply_structure_filters(target_complex)
            if not valid_idx:
                # Try another index before raising IndexError
                idx = random.randrange(len(self))
        if not valid_idx:
            raise IndexError(
                f"Unable to find a valid item in the dataset satisfying filters at {idx} after {attempts} attempts!"
            )
        if self.transform is not None:
            feature_complex = self.transform(feature_complex)
        if self.target_transform is not None:
            target_complex = self.target_transform(target_complex)
        return feature_complex, target_complex
    def apply_structure_filters(self, structure: Structure) -> bool:
        pass_filters = True
        for structure_filter in self.structure_filters:
            if not structure_filter(structure):
                pass_filters = False
                break
        return pass_filters
    def apply_system_filters(self, system: PinderSystem) -> PinderSystem | bool:
        for system_filter in self.system_filters:
            if isinstance(system_filter, filters.PinderFilterBase):
                if not system_filter(system):
                    return False
        return system
    def __repr__(self) -> str:
        return f"CustomPinderDataset(split={self.split}, monomers={self.monomer_priority}, systems={len(self)})"
# Note the selected monomers indicated by Structure.pinder_id attributes. Since we enabled cropping, the feature and target complex AtomArray have identical shapes
test_data = CustomPinderDataset(split="test")
test_data[0]
(Structure(
filepath=/home/runner/.local/share/pinder/2024-02/test_set_pdbs/7rzb__A1_A0A229LVN5-R--7rzb__A2_A0A229LVN5-L.pdb,
uniprot_map=<class 'pandas.core.frame.DataFrame'> with shape (294, 14),
pinder_id='7rzb__A1_A0A229LVN5-R--7rzb__A2_A0A229LVN5-L',
atom_array=<class 'biotite.structure.AtomArray'> with shape (2092,),
pdb_engine='fastpdb',
),
Structure(
filepath=/home/runner/.local/share/pinder/2024-02/test_set_pdbs/7rzb__A1_A0A229LVN5-R--7rzb__A2_A0A229LVN5-L.pdb,
uniprot_map=<class 'pandas.core.frame.DataFrame'> with shape (294, 14),
pinder_id='7rzb__A1_A0A229LVN5-R--7rzb__A2_A0A229LVN5-L',
atom_array=<class 'biotite.structure.AtomArray'> with shape (2092,),
pdb_engine='fastpdb',
))
In the above example, we wrote a torch Dataset that currently returns a tuple of Structure objects: one representing the “decoy” or sample to use for features, and the other representing the ground-truth “label”.
Of course, we wouldn’t use these Structure objects directly in our model. In your model, you will have to choose which features to compute and which data structure to use to represent them.
While the PinderDataset and PPIDataset datasets provide examples of feature encodings, below we will adjust the transform and target_transform arguments to simply return NumPy NDArray objects as a default. Note: you can’t pass Structure objects to a torch DataLoader; you must first convert them into array-like structures supported by the default torch collate_fn. You can implement your own collate functions to adjust this behavior.
def default_transform(structure: Structure) -> NDArray[np.double]:
    return structure.coords
test_data = CustomPinderDataset(split="test", transform=default_transform, target_transform=default_transform)
test_data[0]
(array([[-19.331, -10.568, -4.591],
[-18.082, -10.305, -3.883],
[-16.867, -10.692, -4.727],
...,
[ 13.363, 11.96 , -6.092],
[ 15.081, 9.028, -5.759],
[ 12.459, 10.17 , -6.767]], dtype=float32),
array([[-13.215324 , -11.076902 , 15.214827 ],
[-14.133494 , -10.301851 , 14.38616 ],
[-13.614427 , -8.882866 , 14.150835 ],
...,
[ 6.8049655, -15.96424 , 20.159506 ],
[ 9.54528 , -16.992254 , 18.400843 ],
[ 7.021378 , -15.329907 , 18.152586 ]], dtype=float32))
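Complexes differ in atom count, so with a batch size greater than 1 the default collate function cannot stack coordinate arrays like the ones returned above. One common workaround is to pad every array to the longest one in the batch and return a mask that marks the real atoms. The sketch below uses plain Python lists so that it runs anywhere. pad_collate is a hypothetical helper; in a real pipeline you would build torch tensors instead and pass such a function to the DataLoader as collate_fn.

```python
def pad_collate(coord_sets, pad_value=0.0):
    """Pad variable-length lists of [x, y, z] rows to a common length."""
    max_len = max(len(coords) for coords in coord_sets)
    padded, mask = [], []
    for coords in coord_sets:
        n_pad = max_len - len(coords)
        # Copy the real rows, then append filler rows up to max_len.
        padded.append([list(row) for row in coords] + [[pad_value] * 3 for _ in range(n_pad)])
        # True marks a real atom, False marks padding.
        mask.append([True] * len(coords) + [False] * n_pad)
    return padded, mask


small = [[1.0, 2.0, 3.0]]
large = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]
batch, mask = pad_collate([small, large])
print(len(batch), len(batch[0]), len(batch[0][0]))  # 2 3 3
print(mask[0])  # [True, False, False]
```

The mask matters downstream: losses and metrics should ignore padded positions, otherwise the filler coordinates bias the result.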
Implementing diversity sampling#
Here, we provide an example of how one might use torch.utils.data.WeightedRandomSampler. However, users are free to implement diversity sampling any way they see fit. For this example, we draw each system with a weight inversely proportional to the population size of its pinder cluster, so that large clusters are undersampled.
from torch.utils.data import WeightedRandomSampler
def inverse_cluster_size_sampler(dataset: PinderDataset, replacement: bool = True):
    index = dataset.index
    # Number of systems in each cluster
    cluster_counts = (
        index["cluster_id"].value_counts().rename("cluster_count")
    )
    # Attach the cluster size to every system in the index
    index = index.merge(
        cluster_counts, left_on="cluster_id", right_index=True
    )
    # undersample large clusters
    cluster_weights = 1.0 / torch.tensor(index.cluster_count.values)
    return WeightedRandomSampler(
        weights=cluster_weights,
        # Draw as many samples per epoch as there are clusters
        num_samples=len(
            cluster_counts
        ),
        replacement=replacement,
    )

sampler = inverse_cluster_size_sampler(
    test_data,
    replacement=True,
)
sampler
<torch.utils.data.sampler.WeightedRandomSampler at 0x7f697dc7e200>
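To make the effect of the inverse-size weights concrete, here is a minimal sketch on a made-up index with two clusters. The toy DataFrame and variable names are assumptions for illustration only; the point is that per-item weights of 1 / cluster size sum to the same total for every cluster, so each cluster is equally likely to be drawn regardless of its population.

```python
import pandas as pd
import torch
from torch.utils.data import WeightedRandomSampler

# Toy index: cluster "a" has three members, cluster "b" has one (made-up data).
toy_index = pd.DataFrame({"cluster_id": ["a", "a", "a", "b"]})
cluster_size = toy_index["cluster_id"].map(toy_index["cluster_id"].value_counts())
weights = 1.0 / torch.tensor(cluster_size.values, dtype=torch.double)

# Summing per-item weights within a cluster gives the relative probability
# of drawing any member of that cluster.
per_cluster = pd.Series(weights.numpy()).groupby(toy_index["cluster_id"]).sum()

toy_sampler = WeightedRandomSampler(weights=weights, num_samples=2, replacement=True)
drawn = list(toy_sampler)  # two row positions, sampled with replacement
```

Here both clusters end up with a total weight of 1.0, even though cluster "a" contributes three of the four rows.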
Defining the dataloader#
Now that we have implemented a dataset and sampling function, we can tie everything together to implement the DataLoader.
from torch.utils.data import DataLoader
test_dataloader = DataLoader(
    test_data,
    batch_size=1,
    # Mutually exclusive with sampler
    shuffle=False,
    sampler=sampler,
)
test_features, test_labels = next(iter(test_dataloader))
test_features, test_labels
2024-11-15 12:13:55,075 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
(tensor([[[ -0.7287, 16.4913, 9.8068],
[ -1.6878, 16.1829, 10.8788],
[ -2.9926, 16.9600, 10.6499],
...,
[ -3.8695, -10.1312, 23.1116],
[ -1.7981, -9.4029, 23.5338],
[ -2.8393, -9.4165, 22.7493]]]),
tensor([[[-0.9892, 19.1343, 8.9643],
[-1.5965, 18.3829, 10.0574],
[-3.1128, 18.5439, 10.0904],
...,
[-7.1361, -8.3044, 11.7833],
[-8.0210, -9.1269, 12.3305],
[-7.5183, -7.4610, 10.8316]]]))
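The comment about `shuffle` and `sampler` being mutually exclusive can be checked with a small, self-contained sketch. The toy dataset and uniform weights below are made up for illustration; the behavior shown is that of torch's `DataLoader`, which rejects `shuffle=True` when a sampler is supplied.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy dataset of four scalar items (made-up data).
toy_dataset = TensorDataset(torch.arange(4, dtype=torch.float32))
toy_sampler = WeightedRandomSampler(weights=[0.25] * 4, num_samples=4, replacement=True)

# Requesting shuffling together with a custom sampler is rejected.
try:
    DataLoader(toy_dataset, shuffle=True, sampler=toy_sampler)
    raised = False
except ValueError:
    raised = True

# With shuffle left at False, the sampler alone decides the iteration order.
loader = DataLoader(toy_dataset, batch_size=2, sampler=toy_sampler)
batches = [batch[0] for batch in loader]
```

Since the sampler already randomizes which items are drawn, leaving `shuffle=False` loses nothing.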
Putting it all together, we can now get a train/val/test dataloader as follows:
from typing import Any, Callable

def get_loader(
    dataset: CustomPinderDataset,
    # Pass a sampler instance, e.g. inverse_cluster_size_sampler(dataset).
    # The default is None because the sampler must be built from the dataset;
    # passing the factory function itself is not a valid sampler.
    sampler: torch.utils.data.Sampler | None = None,
    batch_size: int = 2,
    # shuffle is mutually exclusive with sampler
    shuffle: bool = False,
    num_workers: int = 0,
    collate_fn: Callable[[list[tuple[NDArray[np.double], NDArray[np.double]]]], tuple[torch.Tensor, torch.Tensor]] | None = None,
    **kwargs: Any,
) -> "DataLoader[CustomPinderDataset]":
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        **kwargs,
    )

train_data = CustomPinderDataset(
    split="train",
    structure_filters=[filters.MinAtomTypesFilter()],
    metadata_filter="(buried_sasa >= 500)",
    transform=default_transform,
    target_transform=default_transform,
)
train_dataloader = get_loader(
    train_data,
    sampler=inverse_cluster_size_sampler(train_data),
    batch_size=1,
)
train_dataloader
<torch.utils.data.dataloader.DataLoader at 0x7f69817f7e80>
train_features, train_labels = next(iter(train_dataloader))
train_features, train_labels
2024-11-15 12:14:00,725 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
(tensor([[[161.0880, 121.9010, 154.0680],
[159.7830, 122.5170, 153.9220],
[159.8500, 124.0310, 153.9140],
...,
[142.9230, 142.9120, 73.1190],
[142.6480, 142.3420, 74.4990],
[142.6460, 140.8530, 74.4910]]]),
tensor([[[161.0880, 121.9010, 154.0680],
[159.7830, 122.5170, 153.9220],
[159.8500, 124.0310, 153.9140],
...,
[142.9230, 142.9120, 73.1190],
[142.6480, 142.3420, 74.4990],
[142.6460, 140.8530, 74.4910]]]))
Using a larger batch size (collate_fn)#
What if we want to use a larger batch size?
By default, the collate_fn used by the DataLoader is torch.utils.data._utils.collate.default_collate, which expects to be able to stack tensors via torch.stack.
Since different pinder systems have different numbers of atoms, and therefore coordinate arrays of different lengths, using a batch size greater than 1 will cause this function to raise a RuntimeError with a message like:
RuntimeError: stack expects each tensor to be equal size, but got [688, 3] at entry 0 and [1391, 3] at entry 1
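The failure can be reproduced without pinder. The sketch below uses two zero-filled tensors whose shapes are taken from the error message above; their contents are made up.

```python
import torch

# Two coordinate-like tensors with different numbers of atoms.
small = torch.zeros(688, 3)
large = torch.zeros(1391, 3)

try:
    torch.stack([small, large])
    error_message = None
except RuntimeError as exc:
    # torch.stack requires every tensor to have the same shape.
    error_message = str(exc)
```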
We can leverage existing pinder utilities to adapt our default collate_fn to pad tensor dimensions with dummy values so that they can be stacked.
from pinder.core.loader.dataset import pad_and_stack
help(pad_and_stack)
Help on function pad_and_stack in module pinder.core.loader.dataset:
pad_and_stack(tensors: 'list[Tensor]', dim: 'int' = 0, dims_to_pad: 'list[int] | None' = None, value: 'int | float | None' = None) -> 'Tensor'
Pads a list of tensors to the maximum length observed along each dimension and then stacks them along a new dimension (given by `dim`).
Parameters:
tensors (list[Tensor]): A list of tensors to pad and stack
dim (int): The new dimension to stack along.
dims_to_pad (list[int] | None): The dimensions to pad
value (int | float | None, optional): The value to pad with, by default None
Returns:
Tensor: The padded and stacked tensor. Below are examples of input and output shapes
Example 1: Sequence features (although redundant with torch.rnn.utils.pad_sequence)
input: [(2,), (7,)], dim: 0
output: (2, 7)
Example 2: Pair features (e.g., pairwise coordinates)
input: [(4, 4, 3), (7, 7, 3)], dim: 0
output: (2, 7, 7, 3)
def collate_coordinates(batch, coords_pad_value: int = -100):
    feature_coords = []
    target_coords = []
    for x in batch:
        feat, target = x
        # Convert NumPy arrays to float32 tensors before padding
        if isinstance(feat, np.ndarray):
            feat = torch.tensor(feat, dtype=torch.float32)
        if isinstance(target, np.ndarray):
            target = torch.tensor(target, dtype=torch.float32)
        feature_coords.append(feat)
        target_coords.append(target)
    # Pad each tensor to the largest atom count in the batch, then stack
    feature_coords = pad_and_stack(feature_coords, dim=0, value=coords_pad_value)
    target_coords = pad_and_stack(target_coords, dim=0, value=coords_pad_value)
    return feature_coords, target_coords

train_dataloader = get_loader(
    train_data,
    sampler=inverse_cluster_size_sampler(train_data),
    collate_fn=collate_coordinates,
    batch_size=2,
)
train_features, train_labels = next(iter(train_dataloader))
train_features, train_labels
2024-11-15 12:14:01,733 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
2024-11-15 12:14:13,286 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=4, items=7
(tensor([[[ 246.8259, 207.0484, 365.5238],
[ 246.7659, 208.4906, 365.3517],
[ 247.3034, 208.7366, 363.9514],
...,
[ 197.3790, 222.5565, 382.4614],
[ 196.3958, 221.6868, 381.6732],
[ 196.9519, 221.3230, 380.3475]],
[[ 367.6263, 324.4850, 301.8851],
[ 367.0830, 325.2843, 302.9723],
[ 367.2429, 326.7630, 302.6137],
...,
[-100.0000, -100.0000, -100.0000],
[-100.0000, -100.0000, -100.0000],
[-100.0000, -100.0000, -100.0000]]]),
tensor([[[ 251.3640, 191.2350, 360.7730],
[ 251.2140, 190.5020, 359.5220],
[ 251.7570, 191.3220, 358.3540],
...,
[ 203.7410, 227.2750, 385.3440],
[ 204.1620, 227.1150, 383.8930],
[ 204.7830, 225.7870, 383.6390]],
[[ 367.8460, 324.9710, 300.2600],
[ 367.1610, 325.4500, 301.4540],
[ 367.0780, 326.9710, 301.4500],
...,
[-100.0000, -100.0000, -100.0000],
[-100.0000, -100.0000, -100.0000],
[-100.0000, -100.0000, -100.0000]]]))
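Padded rows such as the -100 entries above are placeholders, so downstream computations usually need to ignore them. The following is a minimal sketch of one way to do that, assuming you track the true atom count of each sample and build a boolean mask from it. The helpers `pad_with_mask` and `masked_centroid` are hypothetical and not part of pinder; they use torch's `pad_sequence` rather than `pad_and_stack` so that the example is self-contained.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_with_mask(coords_list, pad_value=-100.0):
    """Pad variable-length (N_i, 3) tensors and return a boolean atom mask.

    Hypothetical helper built on torch's pad_sequence; not part of pinder.
    """
    lengths = torch.tensor([c.shape[0] for c in coords_list])
    padded = pad_sequence(coords_list, batch_first=True, padding_value=pad_value)
    # mask[b, i] is True when atom i of batch element b is a real atom.
    mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]
    return padded, mask


def masked_centroid(padded, mask):
    """Per-sample centroid computed over real atoms only."""
    weights = mask.unsqueeze(-1).to(padded.dtype)  # (B, N_max, 1)
    return (padded * weights).sum(dim=1) / weights.sum(dim=1)


# Toy batch: one sample with 2 atoms, one with 4 (made-up coordinates).
toy_batch = [torch.ones(2, 3), 2 * torch.ones(4, 3)]
padded, mask = pad_with_mask(toy_batch)
centroids = masked_centroid(padded, mask)
```

Building the mask from explicit lengths, rather than by comparing against the pad value, avoids misclassifying a real atom whose coordinates happen to equal the pad value. The same mask can be applied to a per-atom loss so that padding does not contribute to the gradient.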