# Pinder loader


The goal of this tutorial is to provide some hands-on examples of how one can leverage the `pinder` dataset in their ML workflow. Specifically, we will illustrate how you can use various utilities provided by `pinder` to write your own data pipeline. 

While this tutorial will not go into details about how to write your own model, it will cover the basic groundwork necessary to interface with structures in pinder and the associated splits and metadata. You will of course want to implement your own featurization pipelines, data representations, etc., but the hope is that this tutorial clarifies how to access the data and make use of it in an ML framework. 


Before proceeding with this tutorial section, you may find it helpful to review the existing tutorials available in `pinder`. 

Specifically, the tutorials covering:
* [pinder index](https://pinder-org.github.io/pinder/pinder-index.html)
* [pinder system](https://pinder-org.github.io/pinder/pinder-system.html)
* [cropped superposition](https://pinder-org.github.io/pinder/superposition.html)


## Accessing and loading data for training

To access the train and val splits for PINDER, please refer to the [pinder documentation](https://github.com/pinder-org/pinder/tree/main?tab=readme-ov-file#%EF%B8%8F-getting-the-dataset).

Once you have downloaded the pinder dataset, either via the `pinder` package or directly through `gsutil`, you will have all of the necessary files for training. 

To get a list of those systems and their split labels, refer to the `pinder` index. 

**We will start by looking at the most basic way to load items from the training and validation set: via `PinderSystem` objects**

### Recap: PinderSystem and Structure classes

In [1]:
import torch

from pinder.core import get_index, PinderSystem

def get_system(system_id: str) -> PinderSystem:
 return PinderSystem(system_id)


index = get_index()
train = index[index.split == "train"].copy()
system = get_system(train.id.iloc[0])
system
 

PinderSystem(
entry = IndexEntry(
 (
 'split',
 'train',
 ),
 (
 'id',
 '8phr__X4_UNDEFINED--8phr__W4_UNDEFINED',
 ),
 (
 'pdb_id',
 '8phr',
 ),
 (
 'cluster_id',
 'cluster_24559_24559',
 ),
 (
 'cluster_id_R',
 'cluster_24559',
 ),
 (
 'cluster_id_L',
 'cluster_24559',
 ),
 (
 'pinder_s',
 False,
 ),
 (
 'pinder_xl',
 False,
 ),
 (
 'pinder_af2',
 False,
 ),
 (
 'uniprot_R',
 'UNDEFINED',
 ),
 (
 'uniprot_L',
 'UNDEFINED',
 ),
 (
 'holo_R_pdb',
 '8phr__X4_UNDEFINED-R.pdb',
 ),
 (
 'holo_L_pdb',
 '8phr__W4_UNDEFINED-L.pdb',
 ),
 (
 'predicted_R_pdb',
 '',
 ),
 (
 'predicted_L_pdb',
 '',
 ),
 (
 'apo_R_pdb',
 '',
 ),
 (
 'apo_L_pdb',
 '',
 ),
 (
 'apo_R_pdbs',
 '',
 ),
 (
 'apo_L_pdbs',
 '',
 ),
 (
 'holo_R',
 True,
 ),
 (
 'holo_L',
 True,
 ),
 (
 'predicted_R',
 False,
 ),
 (
 'predicted_L',
 False,
 ),
 (
 'apo_R',
 False,
 ),
 (
 'apo_L',
 False,
 ),
 (
 'apo_R_quality',
 '',
 ),
 (
 'apo_L_quality',
 '',
 ),
 (
 'chain1_neff',
 10.78125,
 ),
 (
 'chain2_neff',
 11.1171875,
 ),
 (
 

The `PinderSystem` object exposes the following properties:
* `native` - the ground-truth dimer complex
* `holo_receptor` - the receptor chain (monomer) from the ground-truth complex
* `holo_ligand` - the ligand chain (monomer) from the ground-truth complex
* `apo_receptor` - the canonical _apo_ chain (monomer) paired to the receptor chain
* `apo_ligand` - the canonical _apo_ chain (monomer) paired to the ligand chain
* `pred_receptor` - the AlphaFold2 predicted monomer paired to the receptor chain 
* `pred_ligand` - the AlphaFold2 predicted monomer paired to the ligand chain


These properties are pointers to `Structure` objects. The `Structure` object provides the most direct mode of access to structures and associated properties. 

**Note: not all systems have an apo and/or predicted structure for all chains of the ground-truth dimer complex!** 

As was the case in the example above, when the alternative monomers are not available, the property will have a value of `None`. 

You can determine which systems have which alternative monomer pairings _a priori_ by looking at the boolean columns in the index `apo_R` and `apo_L` for the apo receptor and ligand, and `predicted_R` and `predicted_L` for the predicted receptor and ligand, respectively. 
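As a rough illustration of that lookup, the sketch below works on plain dictionaries standing in for index rows. Only the column names (`holo_R`, `holo_L`, `apo_R`, `apo_L`, `predicted_R`, `predicted_L`) come from the pinder index; the helper functions and toy rows are hypothetical. On the real index, the equivalent pandas filter is a query such as `index.query('apo_R and apo_L')`.

```python
# Toy rows mimicking the boolean availability columns of the pinder index.
# The helpers below are illustrative only; they are not part of pinder.
rows = [
    {"id": "sys_a", "holo_R": True, "holo_L": True, "apo_R": True, "apo_L": True,
     "predicted_R": False, "predicted_L": True},
    {"id": "sys_b", "holo_R": True, "holo_L": True, "apo_R": True, "apo_L": False,
     "predicted_R": True, "predicted_L": True},
]


def available_types(row: dict, side: str) -> set[str]:
    """Monomer types available for one side of the dimer ('R' or 'L')."""
    columns = {"holo": f"holo_{side}", "apo": f"apo_{side}", "pred": f"predicted_{side}"}
    return {kind for kind, column in columns.items() if row[column]}


def shared_types(row: dict) -> set[str]:
    """Monomer types available for both the receptor and the ligand."""
    return available_types(row, "R") & available_types(row, "L")


for row in rows:
    print(row["id"], sorted(shared_types(row)))
# sys_a ['apo', 'holo']
# sys_b ['holo', 'pred']
```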


For instance, we can load a different system that _does_ have both an apo receptor and an apo ligand:

In [2]:
apo_system = get_system(train.query('apo_R and apo_L').id.iloc[0])
receptor = apo_system.apo_receptor
ligand = apo_system.apo_ligand 

receptor, ligand


(Structure(
 filepath=/Users/danielkovtun/.local/share/pinder/2024-02/pdbs/3wdb__A1_P9WPC9.pdb,
 uniprot_map=/Users/danielkovtun/.local/share/pinder/2024-02/mappings/3wdb__A1_P9WPC9.parquet,
 pinder_id='3wdb__A1_P9WPC9',
 atom_array= with shape (1144,),
 pdb_engine='fastpdb',
 ),
 Structure(
 filepath=/Users/danielkovtun/.local/share/pinder/2024-02/pdbs/6ucr__A1_P9WPC9.pdb,
 uniprot_map=/Users/danielkovtun/.local/share/pinder/2024-02/mappings/6ucr__A1_P9WPC9.parquet,
 pinder_id='6ucr__A1_P9WPC9',
 atom_array= with shape (1193,),
 pdb_engine='fastpdb',
 ))

We can now access e.g. the sequence and the coordinates of the structures via the `Structure` objects:

In [3]:
receptor.sequence

'PLGSMFERFTDRARRVVVLAQEEARMLNHNYIGTEHILLGLIHEGEGVAAKSLESLGISLEGVRSQVEEIIGQGQQAPSGHIPFTPRAKKVLELSLREALQLGHNYIGTEHILLGLIREGEGVAAQVLVKLGAELTRVRQQVIQLLSGY'

In [4]:
receptor.coords[0:5]

array([[-12.982, -17.271, -11.271],
 [-14.36 , -17.069, -11.749],
 [-15.261, -16.373, -10.703],
 [-15.461, -15.161, -10.801],
 [-14.842, -18.494, -12.077]], dtype=float32)

We can always access the underlying biotite [AtomArray](https://www.biotite-python.org/latest/apidoc/biotite.structure.AtomArray.html) via the `Structure.atom_array` property:


In [5]:
receptor.atom_array[0:5]

array([
	Atom(np.array([-12.982, -17.271, -11.271], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="N", element="N", b_factor=0.0),
	Atom(np.array([-14.36 , -17.069, -11.749], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="CA", element="C", b_factor=0.0),
	Atom(np.array([-15.261, -16.373, -10.703], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="C", element="C", b_factor=0.0),
	Atom(np.array([-15.461, -15.161, -10.801], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="O", element="O", b_factor=0.0),
	Atom(np.array([-14.842, -18.494, -12.077], dtype=float32), chain_id="R", res_id=2, ins_code="", res_name="PRO", hetero=False, atom_name="CB", element="C", b_factor=0.0)
])

For a more comprehensive overview of all of the `Structure` class properties, refer to the [pinder system](https://pinder-org.github.io/pinder/pinder-system.html) tutorial.
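Because `Structure.atom_array` is a biotite `AtomArray`, per-atom annotations such as `atom_name` can serve as boolean masks to select subsets of atoms, for example the backbone. The NumPy sketch below mimics that with plain arrays standing in for the `atom_name` annotation and the coordinates (toy data, not taken from pinder). With a real `AtomArray`, the same kind of mask can index the array directly, e.g. `atom_array[np.isin(atom_array.atom_name, ["N", "CA", "C", "O"])]`.

```python
import numpy as np

# Toy stand-ins for the AtomArray `atom_name` annotation and coordinates:
# two residues, the first with a CB side-chain atom.
atom_name = np.array(["N", "CA", "C", "O", "CB", "N", "CA", "C", "O"])
coords = np.arange(27, dtype=np.float32).reshape(9, 3)

# Boolean masks select atoms by name, just as they would on an AtomArray.
backbone_mask = np.isin(atom_name, ["N", "CA", "C", "O"])
backbone_coords = coords[backbone_mask]
ca_coords = coords[atom_name == "CA"]

print(backbone_coords.shape, ca_coords.shape)  # (8, 3) (2, 3)
```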


### Using the PinderLoader to load, filter and transform systems

While the `PinderSystem` object provides self-contained access to the structures associated with a dimer system, the `PinderLoader` provides a base abstraction for iterating over systems, applying optional filters and/or transforms, and returning the systems as an iterator. This construct is covered in a [different tutorial](https://pinder-org.github.io/pinder/pinder-loader.html). 

Using the `PinderLoader` is **not** necessary to load systems in your own framework. It is simply one of the provided mechanisms if you find it useful. 

`PinderLoader` brings together filters, transforms and writers to create a generic `PinderSystem` iterator. It takes either a split name or a list of system IDs as input and can be used to sample alternative monomers to form dimer complexes that serve as, e.g., features. 
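Conceptually, that pattern is a generator that drops systems failing any filter and passes the survivors through a chain of transforms. The sketch below shows the idea in plain Python with toy dictionaries; `iterate_systems`, `has_min_contacts` and `mark_loaded` are hypothetical and do not reproduce `PinderLoader`'s actual internals (for example, its sampling of alternative monomers or the exact order in which it runs filters and transforms).

```python
from collections.abc import Callable, Iterable, Iterator
from typing import Any

System = dict[str, Any]


def iterate_systems(
    systems: Iterable[System],
    filters: Iterable[Callable[[System], bool]] = (),
    transforms: Iterable[Callable[[System], System]] = (),
) -> Iterator[System]:
    """Yield systems that pass every filter, after applying each transform in order."""
    filters = list(filters)
    transforms = list(transforms)
    for system in systems:
        if not all(check(system) for check in filters):
            continue  # rejected by at least one filter
        for transform in transforms:
            system = transform(system)
        yield system


def has_min_contacts(system: System) -> bool:
    return system["n_contacts"] >= 5


def mark_loaded(system: System) -> System:
    return {**system, "loaded": True}


toy_systems = [
    {"id": "sys_a", "n_contacts": 12},
    {"id": "sys_b", "n_contacts": 3},
    {"id": "sys_c", "n_contacts": 7},
]
loaded = list(iterate_systems(toy_systems, [has_min_contacts], [mark_loaded]))
print([s["id"] for s in loaded])  # ['sys_a', 'sys_c']
```

Because it is a generator, systems are only processed as they are consumed, which mirrors iterating over a loader with `tqdm` as in the example further down.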


### Loading a specific split
Note: only the test split has subsets defined (`pinder_s`, `pinder_xl`, `pinder_af2`).

For train and val, you could just do:
```python
train_loader = PinderLoader(split="train")
val_loader = PinderLoader(split="val")
```


In [6]:
import torch
from pinder.core import PinderLoader
from pinder.core.loader import filters

base_filters = [
 filters.FilterByMissingHolo(),
 filters.FilterSubByContacts(min_contacts=5, radius=10.0, calpha_only=True),
 filters.FilterDetachedHolo(radius=12, max_components=2),
]
sub_filters = [
 filters.FilterSubByAtomTypes(min_atom_types=4),
 filters.FilterByHoloOverlap(min_overlap=5),
 filters.FilterByHoloSeqIdentity(min_sequence_identity=0.8),
 filters.FilterSubRmsds(rmsd_cutoff=7.5),
 filters.FilterDetachedSub(radius=12, max_components=2),
]

loader = PinderLoader(
 split="test", 
 subset="pinder_af2",
 monomer_priority="holo",
 base_filters = base_filters,
 sub_filters = sub_filters
)

loader

PinderLoader(split=test, monomers=holo, systems=180)

In [7]:
len(loader)

180

In [8]:
data = loader[0]
print(f"Data is a {type(data)}")
system, feature_complex, target_complex = data
type(system), type(feature_complex), type(target_complex)

Data is a 


(pinder.core.index.system.PinderSystem,
 pinder.core.loader.structure.Structure,
 pinder.core.loader.structure.Structure)

In [9]:
# You can also use it as an iterator
from tqdm import tqdm
loaded_ids = []
for (system, feature_complex, target_complex) in tqdm(loader):
 loaded_ids.append(system.entry.id)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 180/180 [01:14<00:00, 2.43it/s]


### Loading a specific list of systems


In [10]:
systems = [
 "1df0__A1_Q07009--1df0__B1_Q64537",
 "117e__A1_P00817--117e__B1_P00817",
]
loader = PinderLoader(
 ids=systems,
 monomer_priority="holo",
 base_filters = base_filters,
 sub_filters = sub_filters
)
passing_ids = []
for item in loader:
 passing_ids.append(item[0].entry.id)

systems_removed_by_filters = set(systems) - set(passing_ids)
systems_removed_by_filters

set()

In [11]:
len(systems) == len(passing_ids)

True

### Optional Pinder writer

Without defining a writer for the `PinderLoader`, the loaded systems are available as a tuple of (`PinderSystem`, `Structure`, `Structure`) objects, containing the original `PinderSystem` and the sampled feature and target complexes, respectively. 

If you want to explicitly write the (potentially transformed) structure objects to a custom location or in a custom format (e.g. PDB, pickle, etc.), you can implement a subclass of `PinderWriterBase`. 

The default writer implements writing to PDB files (leveraging the `Structure.to_pdb` method on the structure objects). 
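The abstract interface of `PinderWriterBase` is not shown in this tutorial, so the sketch below is a standalone, hypothetical writer rather than a subclass of it. It only illustrates the directory layout that the default writer produces in the next cell (one sub-directory per pinder ID under `output_path`), here writing coordinates as JSON instead of PDB.

```python
import json
import tempfile
from pathlib import Path


class JsonSystemWriter:
    """Hypothetical writer: one directory per system, one JSON file per structure."""

    def __init__(self, output_path: Path) -> None:
        self.output_path = Path(output_path)

    def write(self, system_id: str, structures: dict[str, list[list[float]]]) -> Path:
        system_dir = self.output_path / system_id
        system_dir.mkdir(parents=True, exist_ok=True)
        for name, coords in structures.items():
            (system_dir / f"{name}.json").write_text(json.dumps(coords))
        return system_dir


with tempfile.TemporaryDirectory() as tmp:
    writer = JsonSystemWriter(Path(tmp))
    out_dir = writer.write(
        "sys_a",
        {"receptor": [[0.0, 1.0, 2.0]], "ligand": [[3.0, 4.0, 5.0]]},
    )
    written = sorted(p.name for p in out_dir.iterdir())
    print(written)  # ['ligand.json', 'receptor.json']
```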



In [12]:
from pinder.core.loader.writer import PinderDefaultWriter

from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp_dir:
    temp_dir = Path(tmp_dir)
    loader = PinderLoader(
        ids=systems,
        monomer_priority="pred",
        writer=PinderDefaultWriter(temp_dir),
    )
    assert set(loader.index.id) == set(systems)
    for i, r in loader.index.iterrows():
        loaded = loader[i]
        pinder_id = r.id
        system_dir = loader.writer.output_path / pinder_id
        assert system_dir.is_dir()
        print(list(system_dir.glob("af_*.pdb")))


[PosixPath('/var/folders/tt/x223wxwj6dzg3vjjgc_6y5bm0000gn/T/tmpbe5qjtfe/117e__A1_P00817--117e__B1_P00817/af__P00817.pdb')]
[PosixPath('/var/folders/tt/x223wxwj6dzg3vjjgc_6y5bm0000gn/T/tmpbe5qjtfe/1df0__A1_Q07009--1df0__B1_Q64537/af__Q07009.pdb'), PosixPath('/var/folders/tt/x223wxwj6dzg3vjjgc_6y5bm0000gn/T/tmpbe5qjtfe/1df0__A1_Q07009--1df0__B1_Q64537/af__Q64537.pdb')]


## Constructing torch datasets and dataloaders from pinder systems

The remaining sections of this tutorial will be for those interested specifically in torch datasets and dataloaders.

Specifically, we will show how to:
* Implement a PyTorch `Dataset` to interface with pinder data
* Include apo and predicted monomers in the data pipeline, with an option to target specific monomer types or randomly sample from the available types
* Leverage `PinderSystem` and its associated methods to crop apo/predicted monomers to match the ground-truth holo monomers
* Write filters and transforms that operate on `Structure` objects
* Integrate annotations in data filtering and featurization
* Create example features to use for training (you will of course choose your own features) 
* Incorporate diversity sampling in the data loader 


The `pinder.core.loader.dataset` module provides two example implementations of how to integrate the pinder dataset into a torch-based machine learning pipeline.

1. `PinderDataset`: A map-style `torch.utils.data.Dataset` that can be used with torch `DataLoader`s.
2. `PPIDataset`: A `torch_geometric.data.Dataset` designed for use with torch-geometric `DataLoader`s.

Together, the two datasets provide an example implementation of how to abstract away the complexity of loading and processing multiple structures associated with each `PinderSystem` by leveraging the following utilities from pinder:

* `pinder.core.PinderLoader`
* `pinder.core.loader.filters`
* `pinder.core.loader.transforms`

The examples cover two different batch data item structures to illustrate two different use-cases:

* `PinderDataset`: A batch of `(target_complex, feature_complex)` pairs, where `target_complex` and `feature_complex` are `torch.Tensor` objects representing the atomic coordinates and atom types of the holo and sampled (decoy, holo/apo/pred) complexes, respectively.
* `PPIDataset`: A batch of `PairedPDB` objects, where the receptor and ligand are encoded separately in a heterogeneous graph, via `torch_geometric.data.HeteroData`, holding multiple node and/or edge types in disjunct storage objects.


The remaining sections will be split into:
1. Using the `PinderDataset` torch dataset
2. Using the `PPIDataset` torch-geometric dataset
3. How you could implement your own dataset & dataloader


### PinderDataset (torch Dataset)


The `PinderDataset` is an example implementation of a `torch.utils.data.Dataset` that represents its data items as a dict containing the following key-value pairs:
* `target_complex`: The ground-truth holo dimer, represented with a set of default properties encoded as `Tensor`s
* `feature_complex`: The sampled dimer complex, representing "features", also represented with a set of default properties encoded as `Tensor`s
* `id`: The pinder ID for the selected system
* `target_id`: The IDs of the receptor and ligand holo monomers, concatenated into a single ID string
* `sample_id`: The IDs of the sampled receptor and ligand monomers (holo, apo or predicted), concatenated into a single ID string. This can be useful for debugging or, more generally, for tracking which specific monomers were selected when targeting alternative monomers (more on this shortly)


Each of the `target_complex` and `feature_complex` values are dictionaries with structural properties encoded by the `pinder.core.loader.geodata.structure2tensor` function by default:
* `atom_coordinates`
* `atom_types`
* `residue_coordinates`
* `residue_types`
* `residue_ids`

You can choose to use a different representation by overriding the default values of `transform` and `target_transform`.
 
It leverages the `PinderLoader` to apply optional filters and/or transforms and to provide an interface for sampling alternative monomers, and it exposes the `transform` and `target_transform` arguments used by the torch Dataset API. 

For more details on the torch Dataset APIs, please refer to the [tutorials](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#datasets-dataloaders).
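According to the class signature printed by `help(PinderDataset)` below, `transform` and `target_transform` are callables that take a `Structure` and return a tensor or a dict of tensors. A custom transform could, for instance, read `structure.coords` and the `element` annotation of `structure.atom_array`. The sketch below shows only the featurization step on plain arrays, using a hypothetical four-element vocabulary plus an "other" bucket; it is not what `structure2tensor` does internally, and inside a real transform you would convert the arrays with `torch.from_numpy`.

```python
import numpy as np

# Hypothetical element vocabulary; the final one-hot column is an "other" bucket.
ELEMENTS = ["C", "N", "O", "S"]


def featurize_atoms(coords, elements) -> dict[str, np.ndarray]:
    """Return coordinates plus one-hot element types for each atom."""
    coords = np.asarray(coords, dtype=np.float32)
    one_hot = np.zeros((len(elements), len(ELEMENTS) + 1), dtype=np.float32)
    for i, element in enumerate(elements):
        column = ELEMENTS.index(element) if element in ELEMENTS else len(ELEMENTS)
        one_hot[i, column] = 1.0
    return {"atom_coordinates": coords, "atom_types": one_hot}


features = featurize_atoms(
    [[-12.982, -17.271, -11.271], [-14.36, -17.069, -11.749], [0.0, 0.0, 0.0]],
    ["N", "C", "SE"],
)
print(features["atom_types"])
```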

In [13]:
from pinder.core.loader import filters, transforms
from pinder.core.loader.dataset import PinderDataset

base_filters = [
 filters.FilterByMissingHolo(),
 filters.FilterSubByContacts(min_contacts=5, radius=10.0, calpha_only=True),
 filters.FilterDetachedHolo(radius=12, max_components=2),
]
sub_filters = [
 filters.FilterSubByAtomTypes(min_atom_types=4),
 filters.FilterByHoloOverlap(min_overlap=5),
 filters.FilterByHoloSeqIdentity(min_sequence_identity=0.8),
 filters.FilterSubRmsds(rmsd_cutoff=7.5),
 filters.FilterDetachedSub(radius=12, max_components=2),
]
# We can include Structure-level transforms (and filters) which will operate on the target and feature complexes
structure_transforms = [
 transforms.SelectAtomTypes(atom_types=["CA", "N", "C", "O"])
]
train_dataset = PinderDataset(
 split="train", 
 # We can leverage holo, apo, pred, random and random_mixed monomer sampling strategies
 monomer_priority="random_mixed",
 base_filters = base_filters,
 sub_filters = sub_filters,
 structure_transforms=structure_transforms,
)
assert len(train_dataset) == len(get_index().query('split == "train"'))

train_dataset





### Sampling alternative monomers

The `monomer_priority` argument can be used to target different mixes of bound and unbound monomers to use for creating the decoy/feature complex. 

The allowed values for `monomer_priority` are "apo", "holo", "pred", "random" or "random_mixed".

When `monomer_priority` is set to one of the available monomer types (holo, apo, pred), the same monomer type will be selected for both receptor and ligand.

When the monomer priority is "random", a random monomer type will be selected from the set of monomer types available for both the receptor and ligand. This option ensures the same type of monomer is used for the receptor and ligand.

When the monomer priority is "random_mixed", a random monomer type will be selected for each of receptor and ligand, separately.

The `fallback_to_holo` option (enabled by default) silently falls back to holo when the `monomer_priority` is set to one of apo or pred, but the corresponding monomer is not available for the dimer.

This is useful when only one of receptor or ligand has an unbound monomer, but you wish to include apo or predicted structures in your workflow.

If `fallback_to_holo` is disabled, an error will be raised when the `monomer_priority` is set to one of apo or pred, but the corresponding monomer is not available for the dimer.
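The selection rules above can be summarized in a few lines of Python. This is a hypothetical sketch of the logic as described in this section, not pinder's implementation: it assumes holo is available for both chains and reads `fallback_to_holo` as applying per chain, so a system with only an apo receptor yields an apo receptor and a holo ligand. `choose_monomer_types` and its arguments are illustrative names.

```python
from __future__ import annotations

import random

SINGLE_TYPES = ("holo", "apo", "pred")


def choose_monomer_types(
    priority: str,
    available_R: set[str],
    available_L: set[str],
    fallback_to_holo: bool = True,
    rng: random.Random | None = None,
) -> tuple[str, str]:
    """Pick (receptor_type, ligand_type) following the rules described above."""
    rng = rng or random.Random()
    if priority in SINGLE_TYPES:
        chosen = []
        for available in (available_R, available_L):
            if priority in available:
                chosen.append(priority)
            elif fallback_to_holo:
                chosen.append("holo")  # silent fallback for this chain only
            else:
                raise ValueError(f"{priority!r} monomer not available")
        return chosen[0], chosen[1]
    if priority == "random":
        # Same type for both chains, drawn from the types they share.
        kind = rng.choice(sorted(available_R & available_L))
        return kind, kind
    if priority == "random_mixed":
        # Independent draw for each chain.
        return rng.choice(sorted(available_R)), rng.choice(sorted(available_L))
    raise ValueError(f"unknown monomer_priority: {priority!r}")


print(choose_monomer_types("apo", {"holo", "apo"}, {"holo"}))  # ('apo', 'holo')
```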


By default, when apo monomers are selected, the "canonical" apo monomer is used. Although a single canonical apo monomer should be used for evaluation, pinder provides multiple apo monomers paired to a single holo monomer (when available). To include these non-canonical/alternative monomers, specify `use_canonical_apo=False` when constructing the `PinderLoader` or `PinderDataset` objects.


In [14]:
data_item = train_dataset[0]
data_item


{'target_complex': {'atom_types': tensor([[0., 0., 0., ..., 0., 0., 0.],
 [1., 0., 0., ..., 0., 0., 0.],
 [1., 0., 0., ..., 0., 0., 0.],
 ...,
 [1., 0., 0., ..., 0., 0., 0.],
 [1., 0., 0., ..., 0., 0., 0.],
 [0., 0., 1., ..., 0., 0., 0.]]),
 'residue_types': tensor([[16.],
 [16.],
 [16.],
 ...,
 [ 0.],
 [ 0.],
 [ 0.]]),
 'atom_coordinates': tensor([[131.7500, 429.3090, 163.5360],
 [132.6810, 428.2520, 163.1550],
 [133.5150, 428.6750, 161.9500],
 ...,
 [177.7620, 463.8650, 166.9020],
 [177.4130, 465.0800, 167.7550],
 [176.8000, 464.9490, 168.8150]]),
 'residue_coordinates': tensor([[131.7500, 429.3090, 163.5360],
 [132.6810, 428.2520, 163.1550],
 [133.5150, 428.6750, 161.9500],
 ...,
 [177.7620, 463.8650, 166.9020],
 [177.4130, 465.0800, 167.7550],
 [176.8000, 464.9490, 168.8150]]),
 'residue_ids': tensor([ 4., 4., 4., ..., 182., 182., 182.])},
 'feature_complex': {'atom_types': tensor([[0., 0., 0., ..., 0., 0., 0.],
 [1., 0., 0., ..., 0., 0., 0.],
 [1., 0., 0., ..., 0., 0., 0.],
 ...,


In [15]:
# Since we used the default crop_equal_monomer_shapes=True, we expect the feature and target complex coords to have identical shapes
assert (
 data_item["feature_complex"]["atom_coordinates"].shape
 == data_item["target_complex"]["atom_coordinates"].shape
)

data_item["feature_complex"]["atom_coordinates"].shape

torch.Size([1316, 3])
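One way to reason about `crop_equal_monomer_shapes` is as an alignment of the atoms shared by the holo monomer and the sampled monomer, so that both end up with the same number of atoms in the same order. The sketch below is a hypothetical illustration of that idea on `(residue_id, atom_name)` keys. It assumes both monomers already use a common residue numbering; pinder ships per-structure UniProt mapping files (the `uniprot_map` field in the `Structure` repr), but how the actual cropping uses them is not covered here.

```python
def aligned_atom_indices(
    holo_keys: list[tuple[int, str]], other_keys: list[tuple[int, str]]
) -> tuple[list[int], list[int]]:
    """Indices of atoms present in both monomers, paired in holo order."""
    other_index = {key: j for j, key in enumerate(other_keys)}
    holo_idx, other_idx = [], []
    for i, key in enumerate(holo_keys):
        j = other_index.get(key)
        if j is not None:
            holo_idx.append(i)
            other_idx.append(j)
    if not holo_idx:
        raise ValueError("no common atoms found")
    return holo_idx, other_idx


holo = [(1, "N"), (1, "CA"), (2, "N"), (2, "CA")]
apo = [(0, "CA"), (1, "CA"), (1, "N"), (2, "CA")]
print(aligned_atom_indices(holo, apo))  # ([0, 1, 3], [2, 1, 3])
```

Indexing both coordinate arrays with the two returned lists gives equal-shape arrays whose rows correspond atom by atom.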

In [16]:
help(PinderDataset)

Help on class PinderDataset in module pinder.core.loader.dataset:

class PinderDataset(torch.utils.data.dataset.Dataset)
 | PinderDataset(split: 'str | None' = None, index: 'pd.DataFrame | None' = None, metadata: 'pd.DataFrame | None' = None, monomer_priority: 'str' = 'holo', base_filters: 'list[PinderFilterBase]' = [], sub_filters: 'list[PinderFilterSubBase]' = [], structure_filters: 'list[StructureFilter]' = [], structure_transforms: 'list[StructureTransform]' = [], transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = , target_transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = , ids: 'list[str] | None' = None, fallback_to_holo: 'bool' = True, use_canonical_apo: 'bool' = True, crop_equal_monomer_shapes: 'bool' = True, index_query: 'str | None' = None, metadata_query: 'str | None' = None, pre_specified_monomers: 'dict[str, str] | pd.DataFrame | None' = None, **kwargs: 'Any') -> 'None'
 | 
 | Method resolution order:
 | PinderDataset
 | 

### Torch DataLoader for PinderDataset

The `PinderDataset` can be served by a `torch.utils.data.DataLoader`. 

There is a convenience function `pinder.core.loader.dataset.get_torch_loader` for taking a `PinderDataset` and returning a `DataLoader` for the dataset object. 

We can leverage the default `collate_fn` (`pinder.core.loader.dataset.collate_batch`) to merge multiple systems (`Dataset` items) to create mini-batches of tensors:


In [17]:
from pinder.core.loader.dataset import collate_batch, get_torch_loader
from torch.utils.data import DataLoader

batch_size = 2
train_dataloader = get_torch_loader(
 train_dataset, 
 batch_size=batch_size,
 shuffle=True,
 collate_fn=collate_batch,
 num_workers=0, 
)
assert isinstance(train_dataloader, DataLoader)
assert hasattr(train_dataloader, "dataset")

# Get a batch from the dataloader
batch = next(iter(train_dataloader))

# expected batch dict keys
assert set(batch.keys()) == {
 "target_complex",
 "feature_complex",
 "id",
 "sample_id",
 "target_id",
}
assert isinstance(batch["target_complex"], dict)
assert isinstance(batch["target_complex"]["atom_coordinates"], torch.Tensor)
feature_coords = batch["feature_complex"]["atom_coordinates"]
# Ensure batch size propagates to tensor dims
assert feature_coords.shape[0] == batch_size
# Ensure coordinates have dim 3
assert feature_coords.shape[2] == 3


2024-09-05 14:29:47,942 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=5, items=5
2024-09-05 14:29:49,311 | pinder.core.utils.cloud.process_many:23 | INFO : runtime succeeded: 1.37s
2024-09-05 14:29:49,381 | pinder.core.loader.structure:595 | ERROR : no common residues found! 2zu1__A1_P03313--7vy5__C41_P03313-L
2024-09-05 14:29:49,636 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=7, items=7
2024-09-05 14:29:49,936 | pinder.core.utils.cloud.process_many:23 | INFO : runtime succeeded: 0.30s
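The assertions in the cell above expect batched `atom_coordinates` with shape `(batch_size, n_atoms, 3)`, so systems with different atom counts must be brought to a common length. How `collate_batch` handles this is not shown here; the NumPy sketch below illustrates one common approach (zero-padding plus a boolean mask), and `pad_coordinates` is a hypothetical name. In torch, `torch.nn.utils.rnn.pad_sequence(..., batch_first=True)` performs the same padding step.

```python
import numpy as np


def pad_coordinates(coord_list: list[np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
    """Stack (n_i, 3) arrays into (batch, n_max, 3) plus a (batch, n_max) validity mask."""
    n_max = max(coords.shape[0] for coords in coord_list)
    batch = np.zeros((len(coord_list), n_max, 3), dtype=np.float32)
    mask = np.zeros((len(coord_list), n_max), dtype=bool)
    for i, coords in enumerate(coord_list):
        n = coords.shape[0]
        batch[i, :n] = coords
        mask[i, :n] = True  # padded positions stay False
    return batch, mask


batch, mask = pad_coordinates([np.ones((4, 3)), np.ones((2, 3))])
print(batch.shape, mask.sum(axis=1))  # (2, 4, 3) [4 2]
```

The mask lets a loss or attention layer ignore the padded atoms.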


### Torch geometric Dataset 


In [9]:
# Make sure to install torch_cluster
# !pip install torch_cluster

In [18]:
from pinder.core.loader.dataset import PPIDataset
from pinder.core.loader.geodata import NodeRepresentation

nodes = {NodeRepresentation("atom"), NodeRepresentation("residue")}

train_dataset = PPIDataset(
 node_types=nodes,
 split="train",
 monomer1="holo_receptor",
 monomer2="holo_ligand",
 limit_by=5,
 force_reload=True,
 parallel=False,
)
assert len(train_dataset) == 5

train_dataset



Processing...
2024-09-05 14:29:56,429 | pinder.core.loader.dataset:533 | INFO : Finished processing, only 5 systems
Done!


PPIDataset(5)

In [19]:
from torch_geometric.data import HeteroData
from pinder.core import get_index

# Here we check that we correctly capped the number of systems to load to 5 (via limit_by)
pindex = get_index()
raw_ids = set(train_dataset.raw_file_names)
assert len(raw_ids.intersection(set(pindex.id))) == 5
processed_ids = {f.stem for f in train_dataset.processed_file_names}
# Here we ensure that all 5 ids got processed and saved as .pt file on disk
assert len(processed_ids.intersection(set(pindex.id))) == 5

# Let's get an item from the dataset by index 
data_item = train_dataset[0]
assert isinstance(data_item, HeteroData)
data_item



PairedPDB(
 ligand_residue={
 residueid=[158, 1],
 pos=[158, 3],
 edge_index=[2, 1580],
 chain=[1],
 },
 receptor_residue={
 residueid=[171, 1],
 pos=[171, 3],
 edge_index=[2, 1710],
 chain=[1],
 },
 ligand_atom={
 x=[1198, 12],
 pos=[1198, 3],
 edge_index=[2, 11980],
 },
 receptor_atom={
 x=[1358, 12],
 pos=[1358, 3],
 edge_index=[2, 13580],
 },
 pdb={
 id=[1],
 num_nodes=1,
 }
)

In [20]:
# We can also get an item by its system ID 
data_item = train_dataset.get_filename("8phr__X4_UNDEFINED--8phr__W4_UNDEFINED")
data_item



PairedPDB(
 ligand_residue={
 residueid=[158, 1],
 pos=[158, 3],
 edge_index=[2, 1580],
 chain=[1],
 },
 receptor_residue={
 residueid=[171, 1],
 pos=[171, 3],
 edge_index=[2, 1710],
 chain=[1],
 },
 ligand_atom={
 x=[1198, 12],
 pos=[1198, 3],
 edge_index=[2, 11980],
 },
 receptor_atom={
 x=[1358, 12],
 pos=[1358, 3],
 edge_index=[2, 13580],
 },
 pdb={
 id=[1],
 num_nodes=1,
 }
)

In [21]:
train_dataset.print_summary()


PPIDataset (#graphs=5):
+------------+----------+----------+
| | #nodes | #edges |
|------------+----------+----------|
| mean | 2612.4 | 0 |
| std | 1631.6 | 0 |
| min | 999 | 0 |
| quantile25 | 1600 | 0 |
| median | 2343 | 0 |
| quantile75 | 2886 | 0 |
| max | 5234 | 0 |
+------------+----------+----------+
Number of nodes per node type:
+------------+------------------+--------------------+---------------+-----------------+-------+
| | ligand_residue | receptor_residue | ligand_atom | receptor_atom | pdb |
|------------+------------------+--------------------+---------------+-----------------+-------|
| mean | 140.2 | 152.8 | 1103.4 | 1215 | 1 |
| std | 90.2 | 87.5 | 738.9 | 717.1 | 0 |
| min | 58 | 62 | 426 | 452 | 1 |
| quantile25 | 78 | 97 | 622 | 802 | 1 |
| median | 121 | 144 | 958 | 1119 | 1 |
| quantile75 | 158 | 171 | 1198 | 1358 | 1 |
| max | 286 | 290 | 2313 | 2344 | 1 |
+------------+------------------+--------------------+---------------+-----------------+-------+


### PairedPDB torch-geometric HeteroData object

The `PPIDataset` represents its data items as `PairedPDB` objects, where the receptor and ligand are encoded separately in a heterogeneous graph, via `torch_geometric.data.HeteroData`, holding multiple node and/or edge types in disjunct storage objects.

It leverages the `PinderLoader` to apply optional filters and/or transforms and implements a caching system by saving processed systems to disk in `.pt` files.
The `PairedPDB` implements a conversion method `.from_pinder_system` that takes in a `PinderSystem` and converts it to a `PairedPDB` object. 

For more details on the torch-geometric Dataset APIs, please refer to the [tutorials](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html).
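In the outputs above, each node type's `edge_index` has exactly 10 columns per node (e.g. 158 ligand residues and `edge_index=[2, 1580]`), which is consistent with a k-nearest-neighbour graph with k = 10; the `torch_cluster` install note also hints at this. The exact construction used by `PairedPDB` is an assumption here. The NumPy sketch below builds such an `edge_index` from positions; `knn_edge_index` is a hypothetical helper, and the (neighbor, center) row order is one convention among several.

```python
import numpy as np


def knn_edge_index(pos: np.ndarray, k: int) -> np.ndarray:
    """Return a (2, N * k) array of (neighbor, center) pairs, excluding self-loops."""
    diff = pos[:, None, :] - pos[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)  # a node is never its own neighbor
    neighbors = np.argsort(dist, axis=1)[:, :k]  # (N, k) nearest nodes per center
    centers = np.repeat(np.arange(pos.shape[0]), k)
    return np.stack([neighbors.reshape(-1), centers])


rng = np.random.default_rng(0)
pos = rng.normal(size=(158, 3))
print(knn_edge_index(pos, k=10).shape)  # (2, 1580)
```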

In [22]:
from pinder.core.loader.geodata import PairedPDB
from torch_geometric.data import HeteroData

pinder_id = "3s9d__B1_P48551--3s9d__A1_P01563"
system = PinderSystem(pinder_id)

holo_data = PairedPDB.from_pinder_system(
 system=system,
 monomer1="holo_receptor", monomer2="holo_ligand",
 node_types=nodes,
)
assert isinstance(holo_data, HeteroData)
expected_node_types = [
 'ligand_residue', 'receptor_residue', 'ligand_atom', 'receptor_atom'
]
assert holo_data.num_nodes == 2780
assert holo_data.num_edges == 0
assert isinstance(holo_data.num_node_features, dict)
expected_num_feats = {
 'ligand_residue': 0,
 'receptor_residue': 0,
 'ligand_atom': 12,
 'receptor_atom': 12
}
for k, v in expected_num_feats.items():
 assert holo_data.num_node_features[k] == v

assert holo_data.node_types == expected_node_types


holo_data



PairedPDB(
 ligand_residue={
 residueid=[127, 1],
 pos=[127, 3],
 edge_index=[2, 1270],
 chain=[1],
 },
 receptor_residue={
 residueid=[180, 1],
 pos=[180, 3],
 edge_index=[2, 1800],
 chain=[1],
 },
 ligand_atom={
 x=[1032, 12],
 pos=[1032, 3],
 edge_index=[2, 10320],
 },
 receptor_atom={
 x=[1441, 12],
 pos=[1441, 3],
 edge_index=[2, 14410],
 }
)

In [23]:
# You can target specific monomers (apo/holo/pred) for the receptor and ligand
apo_data = PairedPDB.from_pinder_system(
 system=system,
 monomer1="apo_receptor", monomer2="apo_ligand",
 node_types=nodes,
)
assert isinstance(apo_data, HeteroData)

assert apo_data.num_nodes == 3437
assert apo_data.num_edges == 0
assert isinstance(apo_data.num_node_features, dict)
expected_num_feats = {
 'ligand_residue': 0,
 'receptor_residue': 0,
 'ligand_atom': 12,
 'receptor_atom': 12
}
for k, v in expected_num_feats.items():
 assert apo_data.num_node_features[k] == v

assert apo_data.node_types == expected_node_types

apo_data

PairedPDB(
 ligand_residue={
 residueid=[165, 1],
 pos=[165, 3],
 edge_index=[2, 1650],
 chain=[1],
 },
 receptor_residue={
 residueid=[212, 1],
 pos=[212, 3],
 edge_index=[2, 2120],
 chain=[1],
 },
 ligand_atom={
 x=[1350, 12],
 pos=[1350, 3],
 edge_index=[2, 13500],
 },
 receptor_atom={
 x=[1710, 12],
 pos=[1710, 3],
 edge_index=[2, 17100],
 }
)

### Torch geometric DataLoader 

The `PPIDataset` can be served by a `torch_geometric.loader.DataLoader`. 

There is a convenience function `pinder.core.loader.dataset.get_geo_loader` for taking a `PPIDataset` and returning a `DataLoader` for the dataset object. 



In [24]:
from pinder.core.loader.dataset import get_geo_loader
from torch_geometric.loader import DataLoader


loader = get_geo_loader(train_dataset)

assert isinstance(loader, DataLoader)
assert hasattr(loader, "dataset")
ds = loader.dataset
assert len(ds) == 5


loader



In [25]:
ds

PPIDataset(5)

## Implementing your own PyTorch Dataset & DataLoader for pinder


While the previous sections covered two example implementations of leveraging the torch and torch-geometric APIs, below we will illustrate how you can write your own PyTorch data pipeline to feed your model. 

We will focus on writing a minimalistic example that simply fetches PDB files and returns the coordinates as tensors. We will also include an example of a sampler function for the dataloader to implement diversity sampling in your workflow.
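A common form of diversity sampling weights each system by the inverse size of its cluster, so every cluster is drawn about equally often regardless of how many systems it contains. The sketch below assumes the `cluster_id` column of the pinder index as the grouping key (with toy IDs here); `cluster_balanced_weights` is a hypothetical helper. In a torch pipeline, the same weights could be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)` and handed to the `DataLoader` via `sampler=`.

```python
import random
from collections import Counter


def cluster_balanced_weights(cluster_ids: list[str]) -> list[float]:
    """Weight each item by 1 / (size of its cluster); each cluster sums to 1."""
    counts = Counter(cluster_ids)
    return [1.0 / counts[c] for c in cluster_ids]


cluster_ids = ["cluster_1", "cluster_1", "cluster_1", "cluster_2", "cluster_3"]
weights = cluster_balanced_weights(cluster_ids)
rng = random.Random(0)
epoch_indices = rng.choices(range(len(cluster_ids)), weights=weights, k=10)
print([round(w, 3) for w in weights])  # [0.333, 0.333, 0.333, 1.0, 1.0]
```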



### Defining the Dataset

Below we will write a barebones `torch.utils.data.Dataset` object that implements at a minimum:
* `__init__` method 
* `__len__` method returning the number of items in the dataset
* `__getitem__` method that returns an item in the dataset

We will also add an option to include apo and predicted monomer types by adding a `monomer_priority` argument to our dataset. The argument will be set to "random" by default, indicating that we want to randomly sample a monomer type from the set of monomers available for a given system. We could also target a specific monomer type by setting this argument to one of "holo", "apo", "predicted". Not every system in the training set has apo or predicted structures available, so we will also add a `fallback_to_holo` argument to indicate whether we want to use `holo` monomers when the selected monomer type is not available. 

We also define three interfaces for applying filters to our dataset:
1. `metadata_filter`: a query string to apply to the pinder metadata pandas DataFrame
2. `system_filters: list[PinderFilterBase]`: a list of filters that inherit from a base class, `PinderFilterBase`, which serves as the abstraction layer for defining `PinderSystem`-based filters
3. `structure_filters: list[StructureFilter]`: a list of filters that inherit from a base class, `StructureFilter`, which serves as the abstraction layer for defining `Structure`-based filters
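The base-class pattern behind these filter lists can be sketched with Python's `abc` module. The classes below are hypothetical stand-ins, not pinder's `StructureFilter` (whose method names are not shown in this tutorial), and a "structure" is reduced to a dict of residue IDs. The point is that each filter returns `True` to keep an item, and an item is kept only if all filters agree.

```python
from abc import ABC, abstractmethod
from typing import Any


class ToyStructureFilter(ABC):
    """Hypothetical base class: subclasses decide whether to keep a structure."""

    @abstractmethod
    def keep(self, structure: dict[str, Any]) -> bool: ...

    def __call__(self, structure: dict[str, Any]) -> bool:
        return self.keep(structure)


class MinResidues(ToyStructureFilter):
    """Keep structures with at least `min_residues` distinct residue IDs."""

    def __init__(self, min_residues: int) -> None:
        self.min_residues = min_residues

    def keep(self, structure: dict[str, Any]) -> bool:
        return len(set(structure["residue_ids"])) >= self.min_residues


def passes_all(structure: dict[str, Any], filters: list[ToyStructureFilter]) -> bool:
    return all(f(structure) for f in filters)


small = {"residue_ids": [1, 1, 2, 2]}
large = {"residue_ids": list(range(50))}
print(passes_all(small, [MinResidues(5)]), passes_all(large, [MinResidues(5)]))  # False True
```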



In [26]:
import random

import numpy as np
from numpy.typing import NDArray
from torch.utils.data import Dataset
from pinder.core import get_index, get_metadata, PinderSystem
from pinder.core.loader import filters
from pinder.core.loader.loader import _create_target_feature_complex, select_monomer
from pinder.core.loader.structure import Structure

index = get_index()
metadata = get_metadata()


class CustomPinderDataset(Dataset):
    def __init__(
        self,
        split: str,
        monomer_priority: str = "random",
        fallback_to_holo: bool = True,
        crop_equal_monomer_shapes: bool = True,
        use_canonical_apo: bool = True,
        transform=None,
        target_transform=None,
        structure_filters: list[filters.StructureFilter] | None = None,
        system_filters: list[filters.PinderFilterBase] | None = None,
        metadata_filter: str | None = None,
        max_load_attempts: int = 10,
    ) -> None:
        # Which split/mode we are using for the dataset instance
        self.split = split
        self.monomer_priority = monomer_priority
        self.fallback_to_holo = fallback_to_holo
        self.crop_equal_monomer_shapes = crop_equal_monomer_shapes

        # Optional transform and target transform to apply (will be covered shortly)
        self.transform = transform
        self.target_transform = target_transform
        # Optional system-level filters to apply (None means no filters)
        self.system_filters = system_filters or []
        # Optional structure filters to apply (None means no filters)
        self.structure_filters = structure_filters or []

        # Maximum number of times to try sampling another index from the dataset until an exception is raised
        self.max_load_attempts = max_load_attempts
        # Whether we should use canonical apo structures (apo_R/L_pdb columns in pinder index) if apo monomers are selected
        self.use_canonical_apo = use_canonical_apo

        # Define the subset of the pinder index and metadata corresponding to the split of our dataset instance
        self.index = index.query(f'split == "{split}"').reset_index(drop=True)
        self.metadata = metadata[metadata["id"].isin(set(self.index.id))].reset_index(drop=True)
        if metadata_filter:
            try:
                self.metadata = self.metadata.query(metadata_filter).reset_index(drop=True)
            except Exception as e:
                print(f"Failed to apply metadata_filter={metadata_filter}: {e}")

        self.index = self.index[self.index["id"].isin(set(self.metadata.id))].reset_index(drop=True)

    def __len__(self) -> int:
        return len(self.index)

    def __getitem__(self, idx: int) -> tuple[NDArray[np.double], NDArray[np.double]]:
        valid_idx = False
        attempts = 0
        while not valid_idx and attempts < self.max_load_attempts:
            attempts += 1
            row = self.index.iloc[idx]
            system = PinderSystem(row.id)

            system = self.apply_system_filters(system)
            if not isinstance(system, PinderSystem):
                # The system failed a system-level filter: try another random index
                idx = random.randrange(len(self))
                continue

            selected_monomers = select_monomer(
                row,
                self.monomer_priority,
                self.fallback_to_holo,
                self.use_canonical_apo,
            )
            # With the system and selected_monomers objects, we can now create a pair of dimer complexes
            # Below we leverage the existing utility from the PinderLoader (_create_target_feature_complex)
            target_complex, feature_complex = _create_target_feature_complex(
                system, selected_monomers, self.crop_equal_monomer_shapes, self.fallback_to_holo
            )
            valid_idx = self.apply_structure_filters(target_complex)
            if not valid_idx:
                # Try another random index before raising IndexError
                idx = random.randrange(len(self))

        if not valid_idx:
            raise IndexError(
                f"Unable to find a valid item in the dataset satisfying filters at {idx} after {attempts} attempts!"
            )
        if self.transform is not None:
            feature_complex = self.transform(feature_complex)
        if self.target_transform is not None:
            target_complex = self.target_transform(target_complex)
        return feature_complex, target_complex

    def apply_structure_filters(self, structure: Structure) -> bool:
        # Keep the structure only if every structure filter passes
        pass_filters = True
        for structure_filter in self.structure_filters:
            if not structure_filter(structure):
                pass_filters = False
                break
        return pass_filters

    def apply_system_filters(self, system: PinderSystem) -> PinderSystem | bool:
        # Return False as soon as any system filter rejects the system
        for system_filter in self.system_filters:
            if isinstance(system_filter, filters.PinderFilterBase):
                if not system_filter(system):
                    return False
        return system

    def __repr__(self) -> str:
        return f"CustomPinderDataset(split={self.split}, monomers={self.monomer_priority}, systems={len(self)})"



In [27]:
# Note the selected monomers indicated by Structure.pinder_id attributes. Since we enabled cropping, the feature and target complex AtomArray have identical shapes 
test_data = CustomPinderDataset(split="test")
test_data[0]

(Structure(
 filepath=/Users/danielkovtun/.local/share/pinder/2024-02/pdbs/af__A0A229LVN5--af__A0A229LVN5.pdb,
 uniprot_map=None,
 pinder_id='af__A0A229LVN5--af__A0A229LVN5',
 atom_array= with shape (2092,),
 pdb_engine='fastpdb',
 ),
 Structure(
 filepath=/Users/danielkovtun/.local/share/pinder/2024-02/test_set_pdbs/7rzb__A1_A0A229LVN5-R--7rzb__A2_A0A229LVN5-L.pdb,
 uniprot_map= with shape (294, 14),
 pinder_id='7rzb__A1_A0A229LVN5-R--7rzb__A2_A0A229LVN5-L',
 atom_array= with shape (2092,),
 pdb_engine='fastpdb',
 ))

In the above example, we wrote a torch Dataset that currently returns a tuple of two `Structure` objects: the first is the "decoy" (the sample used to compute features) and the second is the ground-truth "label".

Of course we wouldn't use these `Structure` objects directly in our model. In your model, you will have to choose which features to compute and which data structure to use to represent them. 

While the `PinderDataset` and `PPIDataset` datasets provide examples of feature encodings, below we will simply set `transform` and `target_transform` to a function that returns each structure's coordinates as a NumPy NDArray. Note: the default torch DataLoader `collate_fn` cannot batch `Structure` objects, so you must first convert them into array-like structures (e.g., NumPy arrays or tensors) that it supports. You can implement your own collate functions to change this behavior.
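For example, with automatic batching the DataLoader passes `collate_fn` a list of whatever `__getitem__` returned. A hypothetical collate function that simply regroups `(feature, target)` pairs into two lists, without any stacking, would let arbitrary objects such as `Structure` pass through unchanged. The name `list_collate` is an assumption made for this sketch.

```python
from __future__ import annotations

from typing import Any


def list_collate(batch: list[tuple[Any, Any]]) -> tuple[list[Any], list[Any]]:
    """Hypothetical collate_fn: regroup (feature, target) pairs without stacking."""
    features = [feat for feat, _ in batch]
    targets = [target for _, target in batch]
    return features, targets


print(list_collate([("feat_a", "target_a"), ("feat_b", "target_b")]))
# (['feat_a', 'feat_b'], ['target_a', 'target_b'])
```

It could then be passed as `DataLoader(dataset, batch_size=4, collate_fn=list_collate)`. The trade-off is that your model code must handle the per-item objects itself, since nothing is converted to tensors.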

In [28]:
def default_transform(structure: Structure) -> NDArray[np.double]:
    # Return the (n_atoms, 3) coordinate array of the structure
    return structure.coords

In [29]:
test_data = CustomPinderDataset(split="test", transform=default_transform, target_transform=default_transform)
test_data[0]

(array([[-12.6210985, -9.128864 , 17.258345 ],
 [-13.660538 , -8.840154 , 16.265753 ],
 [-13.38884 , -7.5286074, 15.524057 ],
 ...,
 [ 7.1754823, -19.776093 , 21.191608 ],
 [ 9.592421 , -19.601936 , 19.197937 ],
 [ 9.432583 , -17.757698 , 20.458267 ]], dtype=float32),
 array([[-13.215324 , -11.076905 , 15.214827 ],
 [-14.133494 , -10.301853 , 14.386162 ],
 [-13.614427 , -8.882867 , 14.150835 ],
 ...,
 [ 6.8049664, -15.96424 , 20.159506 ],
 [ 9.545281 , -16.992254 , 18.400843 ],
 [ 7.0213795, -15.329907 , 18.152586 ]], dtype=float32))

### Implementing diversity sampling

Here, we provide an example of how one might use `torch.utils.data.WeightedRandomSampler`. However, you are free to implement diversity sampling however you see fit. In this example, each system is sampled with probability inversely proportional to the population size of its pinder cluster, so that large clusters are undersampled.
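The weighting can be checked with plain Python before involving torch. Giving each item a weight of `1 / cluster_size` makes every cluster's total weight equal to 1, so clusters are drawn about equally often regardless of size. The toy `cluster_ids` below are hypothetical. `random.choices` stands in for `WeightedRandomSampler(replacement=True)`, which likewise accepts unnormalized weights.

```python
from __future__ import annotations

import random
from collections import Counter


def inverse_cluster_weights(cluster_ids: list[str]) -> list[float]:
    # Each item gets 1 / (size of its cluster), so each cluster sums to 1
    counts = Counter(cluster_ids)
    return [1.0 / counts[c] for c in cluster_ids]


cluster_ids = ["c1", "c1", "c1", "c2"]  # hypothetical toy cluster labels
weights = inverse_cluster_weights(cluster_ids)
print(weights)  # [0.3333333333333333, 0.3333333333333333, 0.3333333333333333, 1.0]

# Draw item indices with replacement according to the weights
rng = random.Random(0)
draws = rng.choices(range(len(cluster_ids)), weights=weights, k=10_000)
share_c2 = sum(cluster_ids[i] == "c2" for i in draws) / len(draws)
print(share_c2)  # expected to be close to 0.5 even though c2 holds 1 of 4 items
```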


In [30]:
from torch.utils.data import WeightedRandomSampler

def inverse_cluster_size_sampler(
    dataset: CustomPinderDataset, replacement: bool = True
) -> WeightedRandomSampler:
    index = dataset.index
    # Number of systems in each pinder cluster
    cluster_counts = index["cluster_id"].value_counts()
    # Size of each system's cluster, aligned row-by-row with the dataset indices
    cluster_sizes = index["cluster_id"].map(cluster_counts)
    # Undersample large clusters: each system's weight is 1 / (size of its cluster)
    cluster_weights = 1.0 / torch.tensor(cluster_sizes.values, dtype=torch.double)
    return WeightedRandomSampler(
        weights=cluster_weights,
        # One "epoch" draws as many samples as there are clusters
        num_samples=len(cluster_counts),
        replacement=replacement,
    )

sampler = inverse_cluster_size_sampler(
 test_data,
 replacement=True,
)
sampler




### Defining the dataloader 

Now that we have implemented a dataset and sampling function, we can tie everything together to implement the `DataLoader`.
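With automatic batching, a `DataLoader` draws indices from its sampler, calls `__getitem__` for each one, and hands each group of `batch_size` items to `collate_fn`. The single-process generator below is a simplified, hypothetical sketch of that flow. It ignores worker processes, pinned memory, and other `DataLoader` options, and it is not torch's implementation.

```python
from __future__ import annotations

from typing import Any, Callable, Iterable, Iterator, Sequence


def simple_batches(
    dataset: Sequence[Any],
    sampler: Iterable[int],
    batch_size: int,
    collate_fn: Callable[[list[Any]], Any],
) -> Iterator[Any]:
    """Hypothetical sketch of how a DataLoader combines sampler, dataset and collate_fn."""
    batch: list[Any] = []
    for idx in sampler:  # the sampler yields dataset indices
        batch.append(dataset[idx])  # __getitem__ loads one item
        if len(batch) == batch_size:
            yield collate_fn(batch)  # collate_fn merges the items into one batch
            batch = []
    if batch:  # keep the final partial batch, like DataLoader's default drop_last=False
        yield collate_fn(batch)


data = ["a", "b", "c", "d", "e"]
print(list(simple_batches(data, sampler=[4, 0, 2], batch_size=2, collate_fn=list)))
# [['e', 'a'], ['c']]
```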


In [31]:
from torch.utils.data import DataLoader

test_dataloader = DataLoader(
    test_data,
    batch_size=1,
    # Mutually exclusive with sampler
    shuffle=False,
    sampler=sampler,
)
test_features, test_labels = next(iter(test_dataloader))
test_features, test_labels



(tensor([[[-2.5430e+00, 1.7731e+01, -2.8000e-02],
 [-2.4430e+00, 1.6607e+01, 9.4600e-01],
 [-2.3810e+00, 1.7162e+01, 2.3870e+00],
 ...,
 [-3.8556e+01, -1.1730e+00, -1.2576e+01],
 [-3.8980e+01, -7.6000e-02, -1.2154e+01],
 [-3.9255e+01, -1.9450e+00, -1.3256e+01]]]),
 tensor([[[ 12.3608, -5.7149, 7.7741],
 [ 12.8355, -5.7361, 9.1869],
 [ 12.1397, -6.8793, 9.9598],
 ...,
 [ 29.1573, -15.5614, 2.8228],
 [ 29.4802, -14.5300, 2.1956],
 [ 27.9742, -15.8473, 3.0789]]]))

Putting it all together, we can now create a train/val/test dataloader as follows:



In [32]:
from typing import Any, Callable


def get_loader(
    dataset: CustomPinderDataset,
    sampler: torch.utils.data.Sampler | None = None,
    batch_size: int = 2,
    # shuffle is mutually exclusive with sampler
    shuffle: bool = False,
    num_workers: int = 0,
    collate_fn: Callable[[list[tuple[NDArray[np.double], NDArray[np.double]]]], tuple[torch.Tensor, torch.Tensor]] | None = None,
    **kwargs: Any,
) -> DataLoader:
    # Default to inverse-cluster-size sampling unless shuffling was requested
    if sampler is None and not shuffle:
        sampler = inverse_cluster_size_sampler(dataset)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        **kwargs,
    )


train_data = CustomPinderDataset(
 split="train", 
 structure_filters=[filters.MinAtomTypesFilter()], 
 metadata_filter="(buried_sasa >= 500)",
 transform=default_transform, 
 target_transform=default_transform,
)
train_dataloader = get_loader(
 train_data, 
 sampler=inverse_cluster_size_sampler(train_data),
 batch_size=1,
)
train_dataloader




In [33]:
train_features, train_labels = next(iter(train_dataloader))
train_features, train_labels


2024-09-05 14:30:01,983 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=7, items=7
2024-09-05 14:30:03,261 | pinder.core.utils.cloud.process_many:23 | INFO : runtime succeeded: 1.28s


(tensor([[[125.8500, 300.1720, 233.1940],
 [124.5070, 299.9340, 232.7260],
 [124.1540, 301.1870, 231.9160],
 ...,
 [113.6850, 300.0020, 243.7090],
 [114.5000, 301.0400, 243.7710],
 [119.5730, 297.7760, 244.9850]]]),
 tensor([[[125.8500, 300.1720, 233.1940],
 [124.5070, 299.9340, 232.7260],
 [124.1540, 301.1870, 231.9160],
 ...,
 [113.6850, 300.0020, 243.7090],
 [114.5000, 301.0400, 243.7710],
 [119.5730, 297.7760, 244.9850]]]))

### Using a larger batch size (collate_fn)
What if we want to use a larger batch size?

By default, the `collate_fn` used by the DataLoader is `torch.utils.data._utils.collate.default_collate` which expects to be able to stack tensors via `torch.stack`.

Since different pinder systems contain different numbers of atoms, using a batch size greater than 1 will cause `default_collate` to raise a RuntimeError with a message like:
`RuntimeError: stack expects each tensor to be equal size, but got [688, 3] at entry 0 and [1391, 3] at entry 1`


We can leverage existing pinder utilities to write a custom collate_fn that pads each tensor with a dummy value up to the largest size in the batch, so that the tensors can be stacked.
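The padding idea can be sketched in plain Python. Each coordinate list is padded to the longest one in the batch. As an extra design choice not taken from pinder, the sketch also returns a boolean mask, so downstream code can ignore padded atoms without testing for the sentinel value. `pad_coords` is a hypothetical helper, not pinder's `pad_and_stack`.

```python
from __future__ import annotations

Coords = list[tuple[float, float, float]]


def pad_coords(
    batch: list[Coords], pad_value: float = -100.0
) -> tuple[list[Coords], list[list[bool]]]:
    """Hypothetical sketch: pad variable-length coordinate lists to a common length.

    Returns the padded coordinates and a mask that is True for real atoms and
    False for padding.
    """
    max_len = max(len(coords) for coords in batch)
    padded, mask = [], []
    for coords in batch:
        n_pad = max_len - len(coords)
        padded.append(list(coords) + [(pad_value, pad_value, pad_value)] * n_pad)
        mask.append([True] * len(coords) + [False] * n_pad)
    return padded, mask


padded, mask = pad_coords([[(1.0, 2.0, 3.0)], [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]])
print(mask)  # [[True, False], [True, True]]
```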



In [34]:
from pinder.core.loader.dataset import pad_and_stack

help(pad_and_stack)


Help on function pad_and_stack in module pinder.core.loader.dataset:

pad_and_stack(tensors: 'list[Tensor]', dim: 'int' = 0, dims_to_pad: 'list[int] | None' = None, value: 'int | float | None' = None) -> 'Tensor'
 Pads a list of tensors to the maximum length observed along each dimension and then stacks them along a new dimension (given by `dim`).
 
 Parameters:
 tensors (list[Tensor]): A list of tensors to pad and stack
 dim (int): The new dimension to stack along.
 dims_to_pad (list[int] | None): The dimensions to pad
 value (int | float | None, optional): The value to pad with, by default None
 
 Returns:
 Tensor: The padded and stacked tensor. Below are examples of input and output shapes
 Example 1: Sequence features (although redundant with torch.rnn.utils.pad_sequence)
 input: [(2,), (7,)], dim: 0
 output: (2, 7)
 Example 2: Pair features (e.g., pairwise coordinates)
 input: [(4, 4, 3), (7, 7, 3)], dim: 0
 output: (2, 7, 7, 3)



In [35]:
def collate_coordinates(batch, coords_pad_value: int = -100):
    # batch is a list of (feature_coords, target_coords) pairs returned by __getitem__
    feature_coords = []
    target_coords = []
    for feat, target in batch:
        # Convert NumPy arrays to float32 tensors so they can be padded and stacked
        if isinstance(feat, np.ndarray):
            feat = torch.tensor(feat, dtype=torch.float32)
        if isinstance(target, np.ndarray):
            target = torch.tensor(target, dtype=torch.float32)
        feature_coords.append(feat)
        target_coords.append(target)

    # Pad each tensor to the largest atom count in the batch, then stack along a new batch dimension
    feature_coords = pad_and_stack(feature_coords, dim=0, value=coords_pad_value)
    target_coords = pad_and_stack(target_coords, dim=0, value=coords_pad_value)
    return feature_coords, target_coords


train_dataloader = get_loader(
 train_data, 
 sampler=inverse_cluster_size_sampler(train_data),
 collate_fn=collate_coordinates,
 batch_size=2,
)
train_features, train_labels = next(iter(train_dataloader))
train_features, train_labels


2024-09-05 14:30:03,504 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=7, items=7
2024-09-05 14:30:04,679 | pinder.core.utils.cloud.process_many:23 | INFO : runtime succeeded: 1.17s
2024-09-05 14:30:07,053 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=6, items=6
2024-09-05 14:30:07,242 | pinder.core.utils.cloud.process_many:23 | INFO : runtime succeeded: 0.19s


(tensor([[[ 173.7118, 222.4612, 146.3869],
 [ 173.7950, 221.7822, 147.6740],
 [ 175.1321, 221.0380, 147.7151],
 ...,
 [ 162.2876, 298.3475, 158.2268],
 [ 163.3745, 297.5005, 157.0505],
 [ 164.2955, 296.3959, 158.1551]],
 
 [[ 274.2440, 238.6180, 254.4800],
 [ 275.1950, 238.5360, 253.3790],
 [ 274.5000, 238.5590, 252.0220],
 ...,
 [-100.0000, -100.0000, -100.0000],
 [-100.0000, -100.0000, -100.0000],
 [-100.0000, -100.0000, -100.0000]]]),
 tensor([[[ 179.8020, 204.6620, 163.5760],
 [ 179.3740, 203.2420, 163.3880],
 [ 180.6090, 202.3680, 163.1150],
 ...,
 [ 185.2960, 297.5130, 155.7020],
 [ 183.4960, 297.2930, 155.6730],
 [ 182.9220, 298.8890, 156.2510]],
 
 [[ 274.2440, 238.6180, 254.4800],
 [ 275.1950, 238.5360, 253.3790],
 [ 274.5000, 238.5590, 252.0220],
 ...,
 [-100.0000, -100.0000, -100.0000],
 [-100.0000, -100.0000, -100.0000],
 [-100.0000, -100.0000, -100.0000]]]))