{ "cells": [ { "cell_type": "markdown", "id": "b279501d-0553-4756-a611-08ade109f1de", "metadata": {}, "source": [ "# Pinder loader\n", "\n", "\n", "The goal of this tutorial is to provide some hands-on examples of how one can leverage the `pinder` dataset in their ML workflow. Specifically, we will illustrate how you can use various utilities provided by `pinder` to write your own data pipeline. \n", "\n", "While this tutorial will not go into details about how to write your own model, it will cover the basic groundwork necessary to interface with structures in pinder and the associated splits and metadata. You will of course want to implement your own featurization pipelines, data representations, etc. but the hope is that this tutorial clarifies how to access the data and make use of it in an ML framework. \n", "\n", "\n", "Before proceeding with this tutorial section, you may find it helpful to review the existing tutorials available in `pinder`. \n", "\n", "Specifcially, the tutorials covering:\n", "* [pinder index](https://pinder-org.github.io/pinder/pinder-index.html)\n", "* [pinder system](https://pinder-org.github.io/pinder/pinder-system.html)\n", "* [cropped superposition](https://pinder-org.github.io/pinder/superposition.html)\n" ] }, { "cell_type": "markdown", "id": "0aac6d07-925c-4a65-88bd-528a19b9853a", "metadata": {}, "source": [ "## Accessing and loading data for training\n", "\n", "In order to access the train and val splits for PINDER, please refer to the [pinder documentation](https://github.com/pinder-org/pinder/tree/main?tab=readme-ov-file#%EF%B8%8F-getting-the-dataset)\n", "\n", "Once you have downloaded the pinder dataset, either via the `pinder` package or directly through `gsutil`, you will have all of the necessary files for training. \n", "\n", "To get a list of those systems and their split labels, refer to the `pinder` index. 
\n", "\n", "**We will start by looking at the most basic way to load items from the training and validation set: via `PinderSystem` objects**" ] }, { "cell_type": "markdown", "id": "42fa57d9-6621-4b9d-be6a-a06488aacdd7", "metadata": {}, "source": [ "### Recap: PinderSystem and Structure classes" ] }, { "cell_type": "code", "execution_count": 1, "id": "f6cfd590-b725-4fca-8793-ae7b0a52f312", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PinderSystem(\n", "entry = IndexEntry(\n", " (\n", " 'split',\n", " 'train',\n", " ),\n", " (\n", " 'id',\n", " '8phr__X4_UNDEFINED--8phr__W4_UNDEFINED',\n", " ),\n", " (\n", " 'pdb_id',\n", " '8phr',\n", " ),\n", " (\n", " 'cluster_id',\n", " 'cluster_24559_24559',\n", " ),\n", " (\n", " 'cluster_id_R',\n", " 'cluster_24559',\n", " ),\n", " (\n", " 'cluster_id_L',\n", " 'cluster_24559',\n", " ),\n", " (\n", " 'pinder_s',\n", " False,\n", " ),\n", " (\n", " 'pinder_xl',\n", " False,\n", " ),\n", " (\n", " 'pinder_af2',\n", " False,\n", " ),\n", " (\n", " 'uniprot_R',\n", " 'UNDEFINED',\n", " ),\n", " (\n", " 'uniprot_L',\n", " 'UNDEFINED',\n", " ),\n", " (\n", " 'holo_R_pdb',\n", " '8phr__X4_UNDEFINED-R.pdb',\n", " ),\n", " (\n", " 'holo_L_pdb',\n", " '8phr__W4_UNDEFINED-L.pdb',\n", " ),\n", " (\n", " 'predicted_R_pdb',\n", " '',\n", " ),\n", " (\n", " 'predicted_L_pdb',\n", " '',\n", " ),\n", " (\n", " 'apo_R_pdb',\n", " '',\n", " ),\n", " (\n", " 'apo_L_pdb',\n", " '',\n", " ),\n", " (\n", " 'apo_R_pdbs',\n", " '',\n", " ),\n", " (\n", " 'apo_L_pdbs',\n", " '',\n", " ),\n", " (\n", " 'holo_R',\n", " True,\n", " ),\n", " (\n", " 'holo_L',\n", " True,\n", " ),\n", " (\n", " 'predicted_R',\n", " False,\n", " ),\n", " (\n", " 'predicted_L',\n", " False,\n", " ),\n", " (\n", " 'apo_R',\n", " False,\n", " ),\n", " (\n", " 'apo_L',\n", " False,\n", " ),\n", " (\n", " 'apo_R_quality',\n", " '',\n", " ),\n", " (\n", " 'apo_L_quality',\n", " '',\n", " ),\n", " (\n", " 'chain1_neff',\n", " 10.78125,\n", " ),\n", " (\n", " 'chain2_neff',\n", " 11.1171875,\n", " ),\n", " (\n", " 'chain_R',\n", " 'X4',\n", " ),\n", " (\n", " 'chain_L',\n", " 'W4',\n", " ),\n", " (\n", " 'contains_antibody',\n", " False,\n", " ),\n", " (\n", " 'contains_antigen',\n", " False,\n", " ),\n", " (\n", " 'contains_enzyme',\n", " False,\n", " ),\n", ")\n", "native=Structure(\n", " filepath=/Users/danielkovtun/.local/share/pinder/2024-02/pdbs/8phr__X4_UNDEFINED--8phr__W4_UNDEFINED.pdb,\n", " uniprot_map=None,\n", " pinder_id='8phr__X4_UNDEFINED--8phr__W4_UNDEFINED',\n", " atom_array= with shape (2556,),\n", " pdb_engine='fastpdb',\n", ")\n", "holo_receptor=Structure(\n", " filepath=/Users/danielkovtun/.local/share/pinder/2024-02/pdbs/8phr__X4_UNDEFINED-R.pdb,\n", " uniprot_map=/Users/danielkovtun/.local/share/pinder/2024-02/mappings/8phr__X4_UNDEFINED-R.parquet,\n", " pinder_id='8phr__X4_UNDEFINED-R',\n", " atom_array= with shape (1358,),\n", " pdb_engine='fastpdb',\n", ")\n", "holo_ligand=Structure(\n", " filepath=/Users/danielkovtun/.local/share/pinder/2024-02/pdbs/8phr__W4_UNDEFINED-L.pdb,\n", " uniprot_map=/Users/danielkovtun/.local/share/pinder/2024-02/mappings/8phr__W4_UNDEFINED-L.parquet,\n", " pinder_id='8phr__W4_UNDEFINED-L',\n", " atom_array= with shape (1198,),\n", " pdb_engine='fastpdb',\n", ")\n", "apo_receptor=None\n", "apo_ligand=None\n", "pred_receptor=None\n", "pred_ligand=None\n", ")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "\n", "from pinder.core import get_index, PinderSystem\n", "\n", "def 
get_system(system_id: str) -> PinderSystem:\n", " return PinderSystem(system_id)\n", "\n", "\n", "index = get_index()\n", "train = index[index.split == \"train\"].copy()\n", "system = get_system(train.id.iloc[0])\n", "system\n", " " ] }, { "cell_type": "markdown", "id": "ae5e509b-f62d-4c02-9424-46d2be3bf1f0", "metadata": {}, "source": [ "Notice the printed `PinderSystem` object has the following properties:\n", "* `native` - the ground-truth dimer complex\n", "* `holo_receptor` - the receptor chain (monomer) from the ground-truth complex\n", "* `holo_ligand` - the ligand chain (monomer) from the ground-truth complex\n", "* `apo_receptor` - the canonical _apo_ chain (monomer) paired to the receptor chain\n", "* `apo_ligand` - the canonical _apo_ chain (monomer) paired to the ligand chain\n", "* `pred_receptor` - the AlphaFold2 predicted monomer paired to the receptor chain \n", "* `pred_ligand` - the AlphaFold2 predicted monomer paired to the ligand chain\n", "\n", "\n", "These properties are pointers to `Structure` objects. The `Structure` object provides the most direct mode of access to structures and associated properties. \n", "\n", "**Note: not all systems have an apo and/or predicted structure for all chains of the ground-truth dimer complex!** \n", "\n", "As was the case in the example above, when the alternative monomers are not available, the property will have a value of `None`. \n", "\n", "You can determine which systems have which alternative monomer pairings _a priori_ by looking at the boolean columns in the index `apo_R` and `apo_L` for the apo receptor and ligand, and `predicted_R` and `predicted_L` for the predicted receptor and ligand, respectively. \n", "\n", "\n", "For instance, we can load a different system that _does_ have apo receptor and ligand as such:" ] }, { "cell_type": "code", "execution_count": 2, "id": "73287b1a-425f-44ec-8153-cfe85479db60", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Structure(\n", " filepath=/Users/danielkovtun/.local/share/pinder/2024-02/pdbs/3wdb__A1_P9WPC9.pdb,\n", " uniprot_map=/Users/danielkovtun/.local/share/pinder/2024-02/mappings/3wdb__A1_P9WPC9.parquet,\n", " pinder_id='3wdb__A1_P9WPC9',\n", " atom_array= with shape (1144,),\n", " pdb_engine='fastpdb',\n", " ),\n", " Structure(\n", " filepath=/Users/danielkovtun/.local/share/pinder/2024-02/pdbs/6ucr__A1_P9WPC9.pdb,\n", " uniprot_map=/Users/danielkovtun/.local/share/pinder/2024-02/mappings/6ucr__A1_P9WPC9.parquet,\n", " pinder_id='6ucr__A1_P9WPC9',\n", " atom_array= with shape (1193,),\n", " pdb_engine='fastpdb',\n", " ))" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "apo_system = get_system(train.query('apo_R and apo_L').id.iloc[0])\n", "receptor = apo_system.apo_receptor\n", "ligand = apo_system.apo_ligand \n", "\n", "receptor, ligand\n" ] }, { "cell_type": "markdown", "id": "6a890cd2-cb92-4be5-aa05-f87b18e79e25", "metadata": {}, "source": [ "We can now access e.g. 
the sequence and the coordinates of the structures via the `Structure` objects:" ] }, { "cell_type": "code", "execution_count": 3, "id": "d78915a1-2f11-45cd-8c03-0cfa77a93e62", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'PLGSMFERFTDRARRVVVLAQEEARMLNHNYIGTEHILLGLIHEGEGVAAKSLESLGISLEGVRSQVEEIIGQGQQAPSGHIPFTPRAKKVLELSLREALQLGHNYIGTEHILLGLIREGEGVAAQVLVKLGAELTRVRQQVIQLLSGY'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "receptor.sequence" ] }, { "cell_type": "code", "execution_count": 4, "id": "b53305d5-f54d-4130-9536-11d1de44d9ba", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[-12.982, -17.271, -11.271],\n", " [-14.36 , -17.069, -11.749],\n", " [-15.261, -16.373, -10.703],\n", " [-15.461, -15.161, -10.801],\n", " [-14.842, -18.494, -12.077]], dtype=float32)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "receptor.coords[0:5]" ] }, { "cell_type": "markdown", "id": "2ac798a7-2dd2-4200-a1bb-2f261df52d75", "metadata": {}, "source": [ "We can always access the underlying biotite [AtomArray](https://www.biotite-python.org/latest/apidoc/biotite.structure.AtomArray.html) via the `Structure.atom_array` property:\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "fd045572-b95f-4a75-b086-8224e5c9e6d4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([\n", "\tAtom(np.array([-12.982, -17.271, -11.271], dtype=float32), chain_id=\"R\", res_id=2, ins_code=\"\", res_name=\"PRO\", hetero=False, atom_name=\"N\", element=\"N\", b_factor=0.0),\n", "\tAtom(np.array([-14.36 , -17.069, -11.749], dtype=float32), chain_id=\"R\", res_id=2, ins_code=\"\", res_name=\"PRO\", hetero=False, atom_name=\"CA\", element=\"C\", b_factor=0.0),\n", "\tAtom(np.array([-15.261, -16.373, -10.703], dtype=float32), chain_id=\"R\", res_id=2, ins_code=\"\", res_name=\"PRO\", hetero=False, atom_name=\"C\", element=\"C\", b_factor=0.0),\n", "\tAtom(np.array([-15.461, -15.161, -10.801], dtype=float32), chain_id=\"R\", res_id=2, ins_code=\"\", res_name=\"PRO\", hetero=False, atom_name=\"O\", element=\"O\", b_factor=0.0),\n", "\tAtom(np.array([-14.842, -18.494, -12.077], dtype=float32), chain_id=\"R\", res_id=2, ins_code=\"\", res_name=\"PRO\", hetero=False, atom_name=\"CB\", element=\"C\", b_factor=0.0)\n", "])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "receptor.atom_array[0:5]" ] }, { "cell_type": "markdown", "id": "fc63286c-958a-4106-a2a9-6a76c5afcaf4", "metadata": {}, "source": [ "For a more comprehensive overview of all of the `Structure` class properties, refer to the [pinder system](https://pinder-org.github.io/pinder/pinder-system.html) tutorial.\n" ] }, { "cell_type": "markdown", "id": "146046b4-cf0e-4a00-9540-6155f0085c5c", "metadata": {}, "source": [ "### Using the PinderLoader to load, filter and transform systems\n", "\n", "While the `PinderSystem` object provides self-contained access to the structures associated with a dimer system, the `PinderLoader` provides a base abstraction for how to iterate over systems, apply optional filters and/or transforms, and return the systems as an iterator. This construct is covered in a [separate tutorial](https://pinder-org.github.io/pinder/pinder-loader.html). \n", "\n", "Using the `PinderLoader` is **not** necessary to load systems in your own framework. It is simply one of the provided mechanisms if you find it useful. 
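\n", "\n", "If you prefer to roll your own iteration, the index plus the `PinderSystem` constructor is all you need. A minimal sketch, reusing only utilities shown earlier in this tutorial:\n", "\n", "```python\n", "from pinder.core import get_index, PinderSystem\n", "\n", "val_ids = get_index().query('split == \"val\"').id\n", "# lazily construct one PinderSystem per system ID\n", "val_systems = (PinderSystem(system_id) for system_id in val_ids)\n", "```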
\n", "\n", "Pinder loader brings together filters, transforms and writers to create a generic `PinderSystem` iterator. It takes either a split name or a list of system IDs as input and can be used to sample alternative monomers to form dimer complexes to serve as e.g. features. \n" ] }, { "cell_type": "markdown", "id": "95707004-e026-40c4-abc1-a8fa84614b6f", "metadata": {}, "source": [ "### Loading a specific split\n", "Note: only the test dataset has a subset defined (`pinder_s, pinder_xl, pinder_af2`)\n", "\n", "For train and val, you could just do:\n", "```python\n", "train_loader = PinderLoader(split=\"train\")\n", "val_loader = PinderLoader(split=\"val\")\n", "```\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "6e09ccc9-215f-4d8e-a08f-1de96bb42131", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PinderLoader(split=test, monomers=holo, systems=180)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "from pinder.core import PinderLoader\n", "from pinder.core.loader import filters\n", "\n", "base_filters = [\n", " filters.FilterByMissingHolo(),\n", " filters.FilterSubByContacts(min_contacts=5, radius=10.0, calpha_only=True),\n", " filters.FilterDetachedHolo(radius=12, max_components=2),\n", "]\n", "sub_filters = [\n", " filters.FilterSubByAtomTypes(min_atom_types=4),\n", " filters.FilterByHoloOverlap(min_overlap=5),\n", " filters.FilterByHoloSeqIdentity(min_sequence_identity=0.8),\n", " filters.FilterSubRmsds(rmsd_cutoff=7.5),\n", " filters.FilterDetachedSub(radius=12, max_components=2),\n", "]\n", "\n", "loader = PinderLoader(\n", " split=\"test\", \n", " subset=\"pinder_af2\",\n", " monomer_priority=\"holo\",\n", " base_filters = base_filters,\n", " sub_filters = sub_filters\n", ")\n", "\n", "loader" ] }, { "cell_type": "code", "execution_count": 7, "id": "f7bcbe6e-4bd3-435c-9ff0-cf611f6bf9cf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "180" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(loader)" ] }, { "cell_type": "code", "execution_count": 8, "id": "fecd34da-c91d-400c-bbda-64d6cccb8e0b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data is a \n" ] }, { "data": { "text/plain": [ "(pinder.core.index.system.PinderSystem,\n", " pinder.core.loader.structure.Structure,\n", " pinder.core.loader.structure.Structure)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = loader[0]\n", "print(f\"Data is a {type(data)}\")\n", "system, feature_complex, target_complex = data\n", "type(system), type(feature_complex), type(target_complex)" ] }, { "cell_type": "code", "execution_count": 9, "id": "bdb3327d-26d2-46fc-95ea-de2ea2fc0516", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 180/180 [01:14<00:00, 2.43it/s]\n" ] } ], "source": [ "# You can also use it as an iterator\n", "from tqdm import tqdm\n", "loaded_ids = []\n", "for (system, feature_complex, target_complex) in tqdm(loader):\n", " loaded_ids.append(system.entry.id)" ] }, { "cell_type": "markdown", "id": "9a1cb330-7034-4522-9763-093993a1bb4e", "metadata": {}, "source": [ "### Loading a specific list of systems\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "77d0f734-508a-4b1e-a4f3-acb8590869a0", "metadata": {}, "outputs": [ { "data": { 
"text/plain": [ "set()" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "systems = [\n", " \"1df0__A1_Q07009--1df0__B1_Q64537\",\n", " \"117e__A1_P00817--117e__B1_P00817\",\n", "]\n", "loader = PinderLoader(\n", " ids=systems,\n", " monomer_priority=\"holo\",\n", " base_filters = base_filters,\n", " sub_filters = sub_filters\n", ")\n", "passing_ids = []\n", "for item in loader:\n", " passing_ids.append(item[0].entry.id)\n", "\n", "systems_removed_by_filters = set(systems) - set(passing_ids)\n", "systems_removed_by_filters" ] }, { "cell_type": "code", "execution_count": 11, "id": "3644c869-da61-4620-a054-2a3ad163f3f8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(systems) == len(passing_ids)" ] }, { "cell_type": "markdown", "id": "5de9d527-ffe0-4b6d-96a7-c90c137a4eb0", "metadata": {}, "source": [ "### Optional Pinder writer\n", "\n", "Without defining a writer for the `PinderLoader`, the loaded systems are available as a tuple of (`PinderSystem`, `Structure`, `Structure`) objects, containing the original `PinderSystem` and the sampled feature and target complexes, respectively. \n", "\n", "If you want to explicitly write the (potentially transformed) structure objects to a custom location or in a custom format (e.g. PDB, pickle, etc.), you can implement a subclass of `PinderWriterBase`. \n", "\n", "The default writer implements writing to PDB files (leveraging the `Structure.to_pdb` method on the structure objects). \n", "\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "6b653c61-cebd-4ace-baf8-a27f1e011466", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[PosixPath('/var/folders/tt/x223wxwj6dzg3vjjgc_6y5bm0000gn/T/tmpbe5qjtfe/117e__A1_P00817--117e__B1_P00817/af__P00817.pdb')]\n", "[PosixPath('/var/folders/tt/x223wxwj6dzg3vjjgc_6y5bm0000gn/T/tmpbe5qjtfe/1df0__A1_Q07009--1df0__B1_Q64537/af__Q07009.pdb'), PosixPath('/var/folders/tt/x223wxwj6dzg3vjjgc_6y5bm0000gn/T/tmpbe5qjtfe/1df0__A1_Q07009--1df0__B1_Q64537/af__Q64537.pdb')]\n" ] } ], "source": [ "from pinder.core.loader.writer import PinderDefaultWriter\n", "\n", "from pathlib import Path\n", "from tempfile import TemporaryDirectory\n", "\n", "with TemporaryDirectory() as tmp_dir:\n", " temp_dir = Path(tmp_dir)\n", " loader = PinderLoader(\n", " ids=systems,\n", " monomer_priority=\"pred\",\n", " writer=PinderDefaultWriter(temp_dir)\n", " )\n", " assert set(loader.index.id) == set(systems)\n", " for i, r in loader.index.iterrows():\n", " loaded = loader[i]\n", " pinder_id = r.id\n", " system_dir = loader.writer.output_path / pinder_id\n", " assert system_dir.is_dir()\n", " print(list(system_dir.glob(\"af_*.pdb\")))\n" ] }, { "cell_type": "markdown", "id": "bda57a4e-2318-4876-8afb-04e287afceaa", "metadata": {}, "source": [ "## Constructing torch datasets and dataloaders from pinder systems\n", "\n", "The remaining sections of this tutorial will be for those interested specifically in torch datasets and dataloaders.\n", "\n", "Specifically, we will show how to:\n", "* Implement a PyTorch `Dataset` to interface with pinder data\n", "* Include apo and predicted monomers in the data pipeline, with an option to target specific monomer types or randomly sample from the available types\n", "* Leverage `PinderSystem` and its associated methods to crop apo/predicted monomers to match the ground-truth holo monomers\n", "* Write filters and 
transforms that operate on `Structure` objects\n", "* Integrate annotations in data filtering and featurization\n", "* Create example features to use for training (you will of course choose your own features) \n", "* Incorporate diversity sampling in the data loader \n", "\n", "\n", "The `pinder.core.loader.dataset` module provides two example implementations of how to integrate the pinder dataset into a torch-based machine learning pipeline.\n", "\n", "1. `PinderDataset`: A map-style `torch.utils.data.Dataset` that can be used with torch `DataLoader`s.\n", "2. `PPIDataset`: A `torch_geometric.data.Dataset` that can be used with torch-geometric `DataLoader`s.\n", "\n", "Together, the two datasets provide an example implementation of how to abstract away the complexity of loading and processing multiple structures associated with each `PinderSystem` by leveraging the following utilities from pinder:\n", "\n", "* `pinder.core.PinderLoader`\n", "* `pinder.core.loader.filters`\n", "* `pinder.core.loader.transforms`\n", "\n", "The examples cover two different batch data item structures to illustrate two use-cases:\n", "\n", "* `PinderDataset`: A batch of `(target_complex, feature_complex)` pairs, where `target_complex` and `feature_complex` are dictionaries of `torch.Tensor` objects representing the atomic coordinates and atom types of the holo and sampled (decoy, holo/apo/pred) complexes, respectively.\n", "* `PPIDataset`: A batch of `PairedPDB` objects, where the receptor and ligand are encoded separately in a heterogeneous graph, via `torch_geometric.data.HeteroData`, holding multiple node and/or edge types in disjunct storage objects.\n", "\n", "\n", "The remaining sections will be split into:\n", "1. Using the `PinderDataset` torch dataset\n", "2. Using the `PPIDataset` torch-geometric dataset\n", "3. How you could implement your own dataset & dataloader\n" ] }, { "cell_type": "markdown", "id": "d960b6d5-f431-4635-b6a3-8473c78f33cd", "metadata": {}, "source": [ "### PinderDataset (torch Dataset)\n", "\n", "\n", "The `PinderDataset` is an example implementation of a `torch.utils.data.Dataset` that represents its data items as a dict containing the following key/value pairs:\n", "* `target_complex`: The ground-truth holo dimer, represented with a set of default properties encoded as `Tensor`s\n", "* `feature_complex`: The sampled dimer complex, representing \"features\", also represented with a set of default properties encoded as `Tensor`s\n", "* `id`: The pinder ID for the selected system\n", "* `target_id`: The IDs of the receptor and ligand holo monomers, concatenated into a single ID string\n", "* `sample_id`: The IDs of the sampled receptor and ligand monomers (holo, apo, or predicted), concatenated into a single ID string. This can be useful for debugging purposes or generally tracking which specific monomers are selected when targeting alternative monomers (more on this shortly).\n", "\n", "\n", "Each of the `target_complex` and `feature_complex` values is a dictionary with structural properties encoded by the `pinder.core.loader.geodata.structure2tensor` function by default:\n", "* `atom_coordinates`\n", "* `atom_types`\n", "* `residue_coordinates`\n", "* `residue_types`\n", "* `residue_ids`\n", "\n", "You can choose to use a different representation by overriding the default values of `transform` and `target_transform`.\n
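", "\n", "For instance, a transform is just a callable taking a `Structure` and returning tensors. A minimal sketch (the name `coords_transform` is hypothetical; the signature follows the `transform` annotation visible in `help(PinderDataset)` further below):\n", "\n", "```python\n", "import torch\n", "\n", "from pinder.core.loader.dataset import PinderDataset\n", "\n", "def coords_transform(structure) -> dict[str, torch.Tensor]:\n", "    # represent each complex by its atomic coordinates only\n", "    return {\"atom_coordinates\": torch.from_numpy(structure.coords)}\n", "\n", "dataset = PinderDataset(\n", "    split=\"train\",\n", "    transform=coords_transform,\n", "    target_transform=coords_transform,\n", ")\n", "```\n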
", "\n", "Under the hood, `PinderDataset` leverages the `PinderLoader` to apply optional filters and/or transforms, provides an interface for sampling alternative monomers, and exposes the `transform` and `target_transform` arguments used by the torch Dataset API. \n", "\n", "For more details on the torch Dataset APIs, please refer to the [tutorials](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#datasets-dataloaders)." ] }, { "cell_type": "code", "execution_count": 13, "id": "9658c6aa-5720-4c0f-98ae-b79d34ad2754", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pinder.core.loader import filters, transforms\n", "from pinder.core.loader.dataset import PinderDataset\n", "\n", "base_filters = [\n", " filters.FilterByMissingHolo(),\n", " filters.FilterSubByContacts(min_contacts=5, radius=10.0, calpha_only=True),\n", " filters.FilterDetachedHolo(radius=12, max_components=2),\n", "]\n", "sub_filters = [\n", " filters.FilterSubByAtomTypes(min_atom_types=4),\n", " filters.FilterByHoloOverlap(min_overlap=5),\n", " filters.FilterByHoloSeqIdentity(min_sequence_identity=0.8),\n", " filters.FilterSubRmsds(rmsd_cutoff=7.5),\n", " filters.FilterDetachedSub(radius=12, max_components=2),\n", "]\n", "# We can include Structure-level transforms (and filters) which will operate on the target and feature complexes\n", "structure_transforms = [\n", " transforms.SelectAtomTypes(atom_types=[\"CA\", \"N\", \"C\", \"O\"])\n", "]\n", "train_dataset = PinderDataset(\n", " split=\"train\", \n", " # We can leverage holo, apo, pred, random and random_mixed monomer sampling strategies\n", " monomer_priority=\"random_mixed\",\n", " base_filters = base_filters,\n", " sub_filters = sub_filters,\n", " structure_transforms=structure_transforms,\n", ")\n", "assert len(train_dataset) == len(get_index().query('split == \"train\"'))\n", "\n", "train_dataset\n", "\n" ] }, { "cell_type": "markdown", "id": "beadbaca-17e4-4c8f-981d-277e702e640a", "metadata": {}, "source": [ "### Sampling alternative monomers" ] }, { "cell_type": "markdown", "id": "fa2cf4cf-58f9-4ad8-8b89-eb961f141c66", "metadata": {}, "source": [ "The `monomer_priority` argument can be used to target different mixes of bound and unbound monomers to use for creating the decoy/feature complex. \n", "\n", "The allowed values for `monomer_priority` are \"apo\", \"holo\", \"pred\", \"random\" or \"random_mixed\".\n", "\n", "When `monomer_priority` is set to one of the available monomer types (holo, apo, pred), the same monomer type will be selected for both receptor and ligand.\n", "\n", "When the monomer priority is \"random\", a random monomer type will be selected from the set of monomer types available for both the receptor and ligand. This option ensures the same type of monomer is used for the receptor and ligand.\n", "\n", "When the monomer priority is \"random_mixed\", a random monomer type will be selected for each of receptor and ligand, separately.\n", "\n", "Enabling the `fallback_to_holo` option (the default) allows a silent fallback to holo when `monomer_priority` is set to apo or pred but the corresponding monomer is not available for the dimer. This is useful when only one of receptor or ligand has an unbound monomer, but you wish to include apo or predicted structures in your workflow. If `fallback_to_holo` is disabled, an error will be raised in that situation instead.\n", "\n", "\n", "By default, when apo monomers are selected, the \"canonical\" apo monomer is used. Although a single canonical apo monomer should be used for eval, pinder provides multiple apo monomers paired to a single holo monomer (when available). In order to include these non-canonical/alternative monomers, you can specify `use_canonical_apo=False` when constructing the `PinderLoader` or `PinderDataset` objects.\n
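", "\n", "As a concrete sketch (reusing the loader arguments described above), you might prefer apo monomers wherever they exist, silently fall back to holo elsewhere, and also allow non-canonical apo pairings:\n", "\n", "```python\n", "apo_loader = PinderLoader(\n", "    split=\"val\",\n", "    monomer_priority=\"apo\",\n", "    fallback_to_holo=True,  # the default; use holo when no apo pairing exists\n", "    use_canonical_apo=False,  # also sample non-canonical apo monomers\n", ")\n", "```\n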
" ] }, { "cell_type": "code", "execution_count": 14, "id": "383244d7-1a85-4732-9dd4-3455fe4cec88", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'target_complex': {'atom_types': tensor([[0., 0., 0., ..., 0., 0., 0.],\n", " [1., 0., 0., ..., 0., 0., 0.],\n", " [1., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [1., 0., 0., ..., 0., 0., 0.],\n", " [1., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 1., ..., 0., 0., 0.]]),\n", " 'residue_types': tensor([[16.],\n", " [16.],\n", " [16.],\n", " ...,\n", " [ 0.],\n", " [ 0.],\n", " [ 0.]]),\n", " 'atom_coordinates': tensor([[131.7500, 429.3090, 163.5360],\n", " [132.6810, 428.2520, 163.1550],\n", " [133.5150, 428.6750, 161.9500],\n", " ...,\n", " [177.7620, 463.8650, 166.9020],\n", " [177.4130, 465.0800, 167.7550],\n", " [176.8000, 464.9490, 168.8150]]),\n", " 'residue_coordinates': tensor([[131.7500, 429.3090, 163.5360],\n", " [132.6810, 428.2520, 163.1550],\n", " [133.5150, 428.6750, 161.9500],\n", " ...,\n", " [177.7620, 463.8650, 166.9020],\n", " [177.4130, 465.0800, 167.7550],\n", " [176.8000, 464.9490, 168.8150]]),\n", " 'residue_ids': tensor([ 4., 4., 4., ..., 182., 182., 182.])},\n", " 'feature_complex': {'atom_types': tensor([[0., 0., 0., ..., 0., 0., 0.],\n", " [1., 0., 0., ..., 0., 0., 0.],\n", " [1., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [1., 0., 0., ..., 0., 0., 0.],\n", " [1., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 1., ..., 0., 0., 0.]]),\n", " 'residue_types': tensor([[16.],\n", " [16.],\n", " [16.],\n", " ...,\n", " [ 0.],\n", " [ 0.],\n", " [ 0.]]),\n", " 'atom_coordinates': tensor([[131.7500, 429.3090, 163.5360],\n", " [132.6810, 428.2520, 163.1550],\n", " [133.5150, 428.6750, 161.9500],\n", " ...,\n", " [177.7620, 463.8650, 166.9020],\n", " [177.4130, 465.0800, 167.7550],\n", " [176.8000, 464.9490, 168.8150]]),\n", " 'residue_coordinates': tensor([[131.7500, 429.3090, 163.5360],\n", " [132.6810, 428.2520, 163.1550],\n", " [133.5150, 428.6750, 161.9500],\n", " ...,\n", " [177.7620, 463.8650, 166.9020],\n", " [177.4130, 465.0800, 167.7550],\n", " [176.8000, 464.9490, 168.8150]]),\n", " 'residue_ids': tensor([ 4., 4., 4., ..., 182., 182., 182.])},\n", " 'id': '8phr__X4_UNDEFINED--8phr__W4_UNDEFINED',\n", " 'sample_id': '8phr__X4_UNDEFINED-R--8phr__W4_UNDEFINED-L',\n", " 'target_id': '8phr__X4_UNDEFINED-R--8phr__W4_UNDEFINED-L'}" ] }, 
"execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_item = train_dataset[0]\n", "data_item\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "4a6f64b4-f027-4286-b444-68cd9f6b59a4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1316, 3])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Since we used the default option of crop_equal_monomer_shapes, we should expect feature and target complex coords are identical shapes\n", "assert (\n", " data_item[\"feature_complex\"][\"atom_coordinates\"].shape\n", " == data_item[\"target_complex\"][\"atom_coordinates\"].shape\n", ")\n", "\n", "data_item[\"feature_complex\"][\"atom_coordinates\"].shape" ] }, { "cell_type": "code", "execution_count": 16, "id": "01025b2b-1962-43b3-8c4e-a33b6a3d6231", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Help on class PinderDataset in module pinder.core.loader.dataset:\n", "\n", "class PinderDataset(torch.utils.data.dataset.Dataset)\n", " | PinderDataset(split: 'str | None' = None, index: 'pd.DataFrame | None' = None, metadata: 'pd.DataFrame | None' = None, monomer_priority: 'str' = 'holo', base_filters: 'list[PinderFilterBase]' = [], sub_filters: 'list[PinderFilterSubBase]' = [], structure_filters: 'list[StructureFilter]' = [], structure_transforms: 'list[StructureTransform]' = [], transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = , target_transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = , ids: 'list[str] | None' = None, fallback_to_holo: 'bool' = True, use_canonical_apo: 'bool' = True, crop_equal_monomer_shapes: 'bool' = True, index_query: 'str | None' = None, metadata_query: 'str | None' = None, pre_specified_monomers: 'dict[str, str] | pd.DataFrame | None' = None, **kwargs: 'Any') -> 'None'\n", " | \n", " | Method resolution order:\n", " | PinderDataset\n", " | torch.utils.data.dataset.Dataset\n", " | typing.Generic\n", " | builtins.object\n", " | \n", " | Methods defined here:\n", " | \n", " | __getitem__(self, idx: 'int') -> 'dict[str, dict[str, torch.Tensor] | torch.Tensor]'\n", " | \n", " | __init__(self, split: 'str | None' = None, index: 'pd.DataFrame | None' = None, metadata: 'pd.DataFrame | None' = None, monomer_priority: 'str' = 'holo', base_filters: 'list[PinderFilterBase]' = [], sub_filters: 'list[PinderFilterSubBase]' = [], structure_filters: 'list[StructureFilter]' = [], structure_transforms: 'list[StructureTransform]' = [], transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = , target_transform: 'Callable[[Structure], torch.Tensor | dict[str, torch.Tensor]]' = , ids: 'list[str] | None' = None, fallback_to_holo: 'bool' = True, use_canonical_apo: 'bool' = True, crop_equal_monomer_shapes: 'bool' = True, index_query: 'str | None' = None, metadata_query: 'str | None' = None, pre_specified_monomers: 'dict[str, str] | pd.DataFrame | None' = None, **kwargs: 'Any') -> 'None'\n", " | Initialize self. 
See help(type(self)) for accurate signature.\n", " | \n", " | __len__(self) -> 'int'\n", " | \n", " | ----------------------------------------------------------------------\n", " | Data and other attributes defined here:\n", " | \n", " | __annotations__ = {}\n", " | \n", " | __parameters__ = ()\n", " | \n", " | ----------------------------------------------------------------------\n", " | Methods inherited from torch.utils.data.dataset.Dataset:\n", " | \n", " | __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]'\n", " | \n", " | ----------------------------------------------------------------------\n", " | Data descriptors inherited from torch.utils.data.dataset.Dataset:\n", " | \n", " | __dict__\n", " | dictionary for instance variables (if defined)\n", " | \n", " | __weakref__\n", " | list of weak references to the object (if defined)\n", " | \n", " | ----------------------------------------------------------------------\n", " | Data and other attributes inherited from torch.utils.data.dataset.Dataset:\n", " | \n", " | __orig_bases__ = (typing.Generic[+T_co],)\n", " | \n", " | ----------------------------------------------------------------------\n", " | Class methods inherited from typing.Generic:\n", " | \n", " | __class_getitem__(params) from builtins.type\n", " | \n", " | __init_subclass__(*args, **kwargs) from builtins.type\n", " | This method is called when a class is subclassed.\n", " | \n", " | The default implementation does nothing. It may be\n", " | overridden to extend subclasses.\n", "\n" ] } ], "source": [ "help(PinderDataset)" ] }, { "cell_type": "markdown", "id": "bea7b003-2ffc-4660-b1c8-a03cfe5086d2", "metadata": {}, "source": [ "### Torch DataLoader for PinderDataset" ] }, { "cell_type": "markdown", "id": "44316aac-dca9-4902-89c3-96de91e640cb", "metadata": {}, "source": [ "The `PinderDataset` can be served by a `torch.utils.data.DataLoader`. \n", "\n", "There is a convenience function `pinder.core.loader.dataset.get_torch_loader` for taking a `PinderDataset` and returning a `DataLoader` for the dataset object. \n", "\n", "We can leverage the default `collate_fn` (`pinder.core.loader.dataset.collate_batch`) to merge multiple systems (`Dataset` items) to create mini-batches of tensors:\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "fcb50bed-7e0a-4d4b-a1d7-6bc003161aab", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-09-05 14:29:47,942 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=5, items=5\n", "2024-09-05 14:29:49,311 | pinder.core.utils.cloud.process_many:23 | INFO : runtime succeeded: 1.37s\n", "2024-09-05 14:29:49,381 | pinder.core.loader.structure:595 | ERROR : no common residues found! 
2zu1__A1_P03313--7vy5__C41_P03313-L\n", "2024-09-05 14:29:49,636 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=7, items=7\n", "2024-09-05 14:29:49,936 | pinder.core.utils.cloud.process_many:23 | INFO : runtime succeeded: 0.30s\n" ] } ], "source": [ "from pinder.core.loader.dataset import collate_batch, get_torch_loader\n", "from torch.utils.data import DataLoader\n", "\n", "batch_size = 2\n", "train_dataloader = get_torch_loader(\n", " train_dataset, \n", " batch_size=batch_size,\n", " shuffle=True,\n", " collate_fn=collate_batch,\n", " num_workers=0, \n", ")\n", "assert isinstance(train_dataloader, DataLoader)\n", "assert hasattr(train_dataloader, \"dataset\")\n", "\n", "# Get a batch from the dataloader\n", "batch = next(iter(train_dataloader))\n", "\n", "# expected batch dict keys\n", "assert set(batch.keys()) == {\n", " \"target_complex\",\n", " \"feature_complex\",\n", " \"id\",\n", " \"sample_id\",\n", " \"target_id\",\n", "}\n", "assert isinstance(batch[\"target_complex\"], dict)\n", "assert isinstance(batch[\"target_complex\"][\"atom_coordinates\"], torch.Tensor)\n", "feature_coords = batch[\"feature_complex\"][\"atom_coordinates\"]\n", "# Ensure batch size propagates to tensor dims\n", "assert feature_coords.shape[0] == batch_size\n", "# Ensure coordinates have dim 3\n", "assert feature_coords.shape[2] == 3\n" ] }, { "cell_type": "markdown", "id": "66aa43ac-b347-4ee1-8b13-3a8540b088ee", "metadata": {}, "source": [ "### Torch geometric Dataset \n" ] }, { "cell_type": "code", "execution_count": 9, "id": "b0fe7662-674c-4af2-b16f-b38082377335", "metadata": {}, "outputs": [], "source": [ "# Make sure to install torch_cluster\n", "# !pip install torch_cluster" ] }, { "cell_type": "code", "execution_count": 18, "id": "0de517fc-6e7f-4682-af48-980d27638b21", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Processing...\n", "2024-09-05 14:29:56,429 | pinder.core.loader.dataset:533 | INFO : Finished processing, only 5 systems\n", "Done!\n" ] }, { "data": { "text/plain": [ "PPIDataset(5)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pinder.core.loader.dataset import PPIDataset\n", "from pinder.core.loader.geodata import NodeRepresentation\n", "\n", "nodes = {NodeRepresentation(\"atom\"), NodeRepresentation(\"residue\")}\n", "\n", "train_dataset = PPIDataset(\n", " node_types=nodes,\n", " split=\"train\",\n", " monomer1=\"holo_receptor\",\n", " monomer2=\"holo_ligand\",\n", " limit_by=5,\n", " force_reload=True,\n", " parallel=False,\n", ")\n", "assert len(train_dataset) == 5\n", "\n", "train_dataset\n", "\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "5392a214-7426-4be2-bc81-a00702c82c55", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PairedPDB(\n", " ligand_residue={\n", " residueid=[158, 1],\n", " pos=[158, 3],\n", " edge_index=[2, 1580],\n", " chain=[1],\n", " },\n", " receptor_residue={\n", " residueid=[171, 1],\n", " pos=[171, 3],\n", " edge_index=[2, 1710],\n", " chain=[1],\n", " },\n", " ligand_atom={\n", " x=[1198, 12],\n", " pos=[1198, 3],\n", " edge_index=[2, 11980],\n", " },\n", " receptor_atom={\n", " x=[1358, 12],\n", " pos=[1358, 3],\n", " edge_index=[2, 13580],\n", " },\n", " pdb={\n", " id=[1],\n", " num_nodes=1,\n", " }\n", ")" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torch_geometric.data import HeteroData\n", "from pinder.core import get_index\n", "\n", "# Here 
we check that we correctly capped the number of systems to load to 5 (via limit_by)\n", "pindex = get_index()\n", "raw_ids = set(train_dataset.raw_file_names)\n", "assert len(raw_ids.intersection(set(pindex.id))) == 5\n", "processed_ids = {f.stem for f in train_dataset.processed_file_names}\n", "# Here we ensure that all 5 ids got processed and saved as .pt file on disk\n", "assert len(processed_ids.intersection(set(pindex.id))) == 5\n", "\n", "# Let's get an item from the dataset by index \n", "data_item = train_dataset[0]\n", "assert isinstance(data_item, HeteroData)\n", "data_item\n", "\n" ] }, { "cell_type": "code", "execution_count": 20, "id": "6fa927a9-cde9-4f32-acf9-f5c989b32966", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PairedPDB(\n", " ligand_residue={\n", " residueid=[158, 1],\n", " pos=[158, 3],\n", " edge_index=[2, 1580],\n", " chain=[1],\n", " },\n", " receptor_residue={\n", " residueid=[171, 1],\n", " pos=[171, 3],\n", " edge_index=[2, 1710],\n", " chain=[1],\n", " },\n", " ligand_atom={\n", " x=[1198, 12],\n", " pos=[1198, 3],\n", " edge_index=[2, 11980],\n", " },\n", " receptor_atom={\n", " x=[1358, 12],\n", " pos=[1358, 3],\n", " edge_index=[2, 13580],\n", " },\n", " pdb={\n", " id=[1],\n", " num_nodes=1,\n", " }\n", ")" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# We can also get an item by its system ID \n", "data_item = train_dataset.get_filename(\"8phr__X4_UNDEFINED--8phr__W4_UNDEFINED\")\n", "data_item\n", "\n" ] }, { "cell_type": "code", "execution_count": 21, "id": "75e13a76-cf9f-47dd-90c2-370c0652074f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PPIDataset (#graphs=5):\n", "+------------+----------+----------+\n", "| | #nodes | #edges |\n", "|------------+----------+----------|\n", "| mean | 2612.4 | 0 |\n", "| std | 1631.6 | 0 |\n", "| min | 999 | 0 |\n", "| quantile25 | 1600 | 0 |\n", "| median | 2343 | 0 |\n", "| quantile75 | 2886 | 0 |\n", "| max | 5234 | 0 |\n", "+------------+----------+----------+\n", "Number of nodes per node type:\n", "+------------+------------------+--------------------+---------------+-----------------+-------+\n", "| | ligand_residue | receptor_residue | ligand_atom | receptor_atom | pdb |\n", "|------------+------------------+--------------------+---------------+-----------------+-------|\n", "| mean | 140.2 | 152.8 | 1103.4 | 1215 | 1 |\n", "| std | 90.2 | 87.5 | 738.9 | 717.1 | 0 |\n", "| min | 58 | 62 | 426 | 452 | 1 |\n", "| quantile25 | 78 | 97 | 622 | 802 | 1 |\n", "| median | 121 | 144 | 958 | 1119 | 1 |\n", "| quantile75 | 158 | 171 | 1198 | 1358 | 1 |\n", "| max | 286 | 290 | 2313 | 2344 | 1 |\n", "+------------+------------------+--------------------+---------------+-----------------+-------+\n" ] } ], "source": [ "train_dataset.print_summary()\n" ] }, { "cell_type": "markdown", "id": "d9f3308e-4711-4424-b5c3-afa0d0fcccce", "metadata": {}, "source": [ "### PairedPDB torch-geometric HeteroData object\n", "\n", "The `PPIDataset` represents its data items as `PairedPDB` objects, where the receptor and ligand are encoded separately in a heterogeneous graph, via `torch_geometric.data.HeteroData`, holding multiple node and/or edge types in disjunct storage objects.\n", "\n", "It leverages the `PinderLoader` to apply optional filters and/or transforms and implements a caching system by saving processed systems to disk in `.pt` files.\n", "The `PairedPDB` implements a conversion method `.from_pinder_system` that takes in a 
`PinderSystem` and converts it to a `PairedPDB` object. \n", "\n", "For more details on the torch-geometric Dataset APIs, please refer to the [tutorials](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html)." ] }, { "cell_type": "code", "execution_count": 22, "id": "ce872929-1706-4a09-ad67-63457511099f", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "PairedPDB(\n", " ligand_residue={\n", " residueid=[127, 1],\n", " pos=[127, 3],\n", " edge_index=[2, 1270],\n", " chain=[1],\n", " },\n", " receptor_residue={\n", " residueid=[180, 1],\n", " pos=[180, 3],\n", " edge_index=[2, 1800],\n", " chain=[1],\n", " },\n", " ligand_atom={\n", " x=[1032, 12],\n", " pos=[1032, 3],\n", " edge_index=[2, 10320],\n", " },\n", " receptor_atom={\n", " x=[1441, 12],\n", " pos=[1441, 3],\n", " edge_index=[2, 14410],\n", " }\n", ")" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pinder.core.loader.geodata import PairedPDB\n", "from torch_geometric.data import HeteroData\n", "\n", "pinder_id = \"3s9d__B1_P48551--3s9d__A1_P01563\"\n", "system = PinderSystem(pinder_id)\n", "\n", "holo_data = PairedPDB.from_pinder_system(\n", " system=system,\n", " monomer1=\"holo_receptor\", monomer2=\"holo_ligand\",\n", " node_types=nodes,\n", ")\n", "assert isinstance(holo_data, HeteroData)\n", "expected_node_types = [\n", " 'ligand_residue', 'receptor_residue', 'ligand_atom', 'receptor_atom'\n", "]\n", "assert holo_data.num_nodes == 2780\n", "assert holo_data.num_edges == 0\n", "assert isinstance(holo_data.num_node_features, dict)\n", "expected_num_feats = {\n", " 'ligand_residue': 0,\n", " 'receptor_residue': 0,\n", " 'ligand_atom': 12,\n", " 'receptor_atom': 12\n", "}\n", "for k, v in expected_num_feats.items():\n", " assert holo_data.num_node_features[k] == v\n", "\n", "assert holo_data.node_types == expected_node_types\n", "\n", "\n", "holo_data\n", "\n" ] }, { "cell_type": "code", "execution_count": 23, "id": "b2913a8f-2992-46de-ab74-1fca594d6088", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PairedPDB(\n", " ligand_residue={\n", " residueid=[165, 1],\n", " pos=[165, 3],\n", " edge_index=[2, 1650],\n", " chain=[1],\n", " },\n", " receptor_residue={\n", " residueid=[212, 1],\n", " pos=[212, 3],\n", " edge_index=[2, 2120],\n", " chain=[1],\n", " },\n", " ligand_atom={\n", " x=[1350, 12],\n", " pos=[1350, 3],\n", " edge_index=[2, 13500],\n", " },\n", " receptor_atom={\n", " x=[1710, 12],\n", " pos=[1710, 3],\n", " edge_index=[2, 17100],\n", " }\n", ")" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# You can target specific monomers (apo/holo/pred) for the receptor and ligand\n", "apo_data = PairedPDB.from_pinder_system(\n", " system=system,\n", " monomer1=\"apo_receptor\", monomer2=\"apo_ligand\",\n", " node_types=nodes,\n", ")\n", "assert isinstance(apo_data, HeteroData)\n", "\n", "assert apo_data.num_nodes == 3437\n", "assert apo_data.num_edges == 0\n", "assert isinstance(apo_data.num_node_features, dict)\n", "expected_num_feats = {\n", " 'ligand_residue': 0,\n", " 'receptor_residue': 0,\n", " 'ligand_atom': 12,\n", " 'receptor_atom': 12\n", "}\n", "for k, v in expected_num_feats.items():\n", " assert apo_data.num_node_features[k] == v\n", "\n", "assert apo_data.node_types == expected_node_types\n", "\n", "apo_data" ] }, { "cell_type": "markdown", "id": "bd294fb0-bc53-45b6-a169-33b0e8460632", "metadata": {}, "source": [ "### Torch geometric DataLoader \n", 
"\n", "The `PPIDataset` can be served by a `torch_geometric.DataLoader`. \n", "\n", "There is a convenience function `pinder.core.loader.dataset.get_geo_loader` for taking a PPIDataset and returning a `DataLoader` for the dataset object. \n", "\n" ] }, { "cell_type": "code", "execution_count": 24, "id": "29ca7612-8a1b-4cf7-a790-cc2b5b26e7e5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pinder.core.loader.dataset import get_geo_loader\n", "from torch_geometric.loader import DataLoader\n", "\n", "\n", "loader = get_geo_loader(train_dataset)\n", "\n", "assert isinstance(loader, DataLoader)\n", "assert hasattr(loader, \"dataset\")\n", "ds = loader.dataset\n", "assert len(ds) == 5\n", "\n", "\n", "loader" ] }, { "cell_type": "code", "execution_count": 25, "id": "66f9a0ef-1c1f-47bd-ba86-f2c3340c34ec", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PPIDataset(5)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds" ] }, { "cell_type": "markdown", "id": "b6a13541-0eb5-46b5-af6a-9c17cf94e6bf", "metadata": {}, "source": [ "## Implementing your own PyTorch Dataset & DataLoader for pinder\n", "\n", "\n", "While the previous sections covered two example implementations of leveraging the torch and torch-geometric APIs, below we will illustrate how you can write your own PyTorch data pipeline to feed your model. \n", "\n", "We will focus on writing a minimalistic example that simply fetches PDB files and returns the coordinates as tensors. We will also include an example of a sampler function for the dataloader to implement diversity sampling in your workflow.\n", "\n" ] }, { "cell_type": "markdown", "id": "28e6a8c0-30ae-4c7b-8f6b-c3fae5560032", "metadata": {}, "source": [ "### Defining the Dataset\n", "\n", "Below we will write a barebones `torch.utils.data.Dataset` object that implements at a minimum:\n", "* `__init__` method \n", "* `__len__` method returning the number of items in the dataset\n", "* `__getitem__` method that returns an item in the dataset\n", "\n", "We will also add an option to include apo and predicted monomer types by adding a `monomer_priority` argument to our dataset. The argument will be set to \"random\" by default, indicating that we want to randomly sample a monomer type from the set of monomers available for a given system. We could also target a specific monomer type by setting this argument to one of \"holo\", \"apo\", \"predicted\". Not every system in the training set has apo or predicted structures available, so we will also add a `fallback_to_holo` argument to indicate whether we want to use `holo` monomers when the selected monomer type is not available. \n", "\n", "We also define two interfaces for applying filters to our dataset:\n", "1. `metadata_filter`: a query string to apply to the pinder metadata pandas DataFrame\n", "2. `system_filters: list[PinderFilterBase]`: a list of filters that inheret a base class, `PinderFilterBase`, which serves as the abstraction layer for defining `PinderSystem`-based filters\n", "3. 
", "\n" ] }, { "cell_type": "code", "execution_count": 26, "id": "bb468005-5ffb-45e8-9f03-f073cb6f7527", "metadata": {}, "outputs": [], "source": [ "import random  # used to resample indices in __getitem__\n", "\n", "import numpy as np\n", "from numpy.typing import NDArray\n", "from torch.utils.data import Dataset\n", "from pinder.core import get_index, get_metadata\n", "from pinder.core.loader.loader import _create_target_feature_complex, select_monomer\n", "from pinder.core.loader.structure import Structure\n", "\n", "index = get_index()\n", "metadata = get_metadata()\n", "\n", "\n", "class CustomPinderDataset(Dataset):\n", " def __init__(\n", " self, \n", " split: str, \n", " monomer_priority: str = \"random\", \n", " fallback_to_holo: bool = True, \n", " crop_equal_monomer_shapes: bool = True,\n", " use_canonical_apo: bool = True,\n", " transform=None, \n", " target_transform=None,\n", " structure_filters: list[filters.StructureFilter] = [],\n", " system_filters: list[filters.PinderFilterBase] = [],\n", " metadata_filter: str | None = None,\n", " max_load_attempts: int = 10,\n", " ) -> None:\n", " # Which split/mode we are using for the dataset instance\n", " self.split = split\n", " self.monomer_priority = monomer_priority\n", " self.fallback_to_holo = fallback_to_holo\n", " self.crop_equal_monomer_shapes = crop_equal_monomer_shapes\n", "\n", " # Optional transform and target transform to apply (will be covered shortly)\n", " self.transform = transform\n", " self.target_transform = target_transform\n", " # Optional system-level filters to apply\n", " self.system_filters = system_filters\n", " # Optional structure filters to apply\n", " self.structure_filters = structure_filters\n", "\n", " # Maximum number of times to try sampling another index from the dataset until an exception is raised\n", " self.max_load_attempts = max_load_attempts\n", " # Whether we should use canonical apo structures (apo_R/L_pdb columns in pinder index) if apo monomers are selected\n", " self.use_canonical_apo = use_canonical_apo\n", " \n", " # Define the subset of the pinder index and metadata corresponding to the split of our dataset instance \n", " self.index = index.query(f'split == \"{split}\"').reset_index(drop=True)\n", " self.metadata = metadata[metadata[\"id\"].isin(set(self.index.id))].reset_index(drop=True)\n", " if metadata_filter:\n", " try:\n", " self.metadata = self.metadata.query(metadata_filter).reset_index(drop=True)\n", " except Exception as e:\n", " print(f\"Failed to apply metadata_filter={metadata_filter}: {e}\")\n", " \n", " self.index = self.index[self.index[\"id\"].isin(set(self.metadata.id))].reset_index(drop=True)\n", "\n", " def __len__(self):\n", " return len(self.index)\n", " \n", " def __getitem__(self, idx: int) -> tuple[NDArray[np.double], NDArray[np.double]]:\n", " valid_idx = False\n", " attempts = 0\n", " while not valid_idx and attempts < self.max_load_attempts:\n", " attempts += 1\n", " row = self.index.iloc[idx]\n", " system = PinderSystem(row.id)\n", " \n", " system = self.apply_system_filters(system)\n", " if not isinstance(system, PinderSystem):\n", " continue\n", "\n", " selected_monomers = select_monomer(\n", " row,\n", " self.monomer_priority,\n", " self.fallback_to_holo,\n", " self.use_canonical_apo,\n", " )\n", " # With the system and selected_monomers objects, we can now create a pair of dimer complexes\n", " # Below we leverage the existing 
utility from the PinderLoader (_create_target_feature_complex)\n", " target_complex, feature_complex = _create_target_feature_complex(\n", " system, selected_monomers, self.crop_equal_monomer_shapes, self.fallback_to_holo\n", " )\n", " valid_idx = self.apply_structure_filters(target_complex)\n", " if not valid_idx:\n", " # Try another index before raising IndexError\n", " idx = random.choice(list(range(len(self))))\n", "\n", " if not valid_idx:\n", " raise IndexError(\n", " f\"Unable to find a valid item in the dataset satisfying filters at {idx} after {attempts} attempts!\"\n", " )\n", " if self.transform is not None:\n", " feature_complex = self.transform(feature_complex)\n", " if self.target_transform is not None:\n", " target_complex = self.target_transform(target_complex)\n", " return feature_complex, target_complex\n", "\n", " def apply_structure_filters(self, structure: Structure) -> bool:\n", " pass_filters = True\n", " for structure_filter in self.structure_filters:\n", " if not structure_filter(structure):\n", " pass_filters = False\n", " break\n", " return pass_filters\n", "\n", " def apply_system_filters(self, system: PinderSystem) -> PinderSystem | bool:\n", " for system_filter in self.system_filters:\n", " if isinstance(system_filter, filters.PinderFilterBase):\n", " if not system_filter(system):\n", " return False\n", " return system\n", "\n", " def __repr__(self) -> str:\n", " return f\"CustomPinderDataset(split={self.split}, monomers={self.monomer_priority}, systems={len(self)})\"\n", "\n" ] }, { "cell_type": "code", "execution_count": 27, "id": "def48617-4288-452f-af9e-55f3289d7788", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Structure(\n", " filepath=/Users/danielkovtun/.local/share/pinder/2024-02/pdbs/af__A0A229LVN5--af__A0A229LVN5.pdb,\n", " uniprot_map=None,\n", " pinder_id='af__A0A229LVN5--af__A0A229LVN5',\n", " atom_array= with shape (2092,),\n", " pdb_engine='fastpdb',\n", " ),\n", " Structure(\n", " filepath=/Users/danielkovtun/.local/share/pinder/2024-02/test_set_pdbs/7rzb__A1_A0A229LVN5-R--7rzb__A2_A0A229LVN5-L.pdb,\n", " uniprot_map= with shape (294, 14),\n", " pinder_id='7rzb__A1_A0A229LVN5-R--7rzb__A2_A0A229LVN5-L',\n", " atom_array= with shape (2092,),\n", " pdb_engine='fastpdb',\n", " ))" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Note the selected monomers indicated by Structure.pinder_id attributes. Since we enabled cropping, the feature and target complex AtomArray have identical shapes \n", "test_data = CustomPinderDataset(split=\"test\")\n", "test_data[0]" ] }, { "cell_type": "markdown", "id": "89ed5d5a-0502-4ba8-a048-0da86aca989a", "metadata": {}, "source": [ "In the above example, we wrote a torch Dataset that currently returns a tuple of `Structure` objects (one representing the \"decoy\" or sample to use for features and the other representing the ground-truth \"label\").\n", "\n", "Of course we wouldn't use these `Structure` objects directly in our model. You will have to choose which features to compute and which data structure to use to represent them. \n", "\n", "While the `PinderDataset` and `PPIDataset` datasets provide examples of feature encodings, below we will adjust the `transform` and `target_transform` objects to simply return NumPy NDArray objects as a default. Note: you can't pass `Structure` objects to a torch DataLoader; you must first convert them into array-like structures supported by the default torch collate_fn. 
You can implement your own collate functions to adjust this behavior." ] }, { "cell_type": "code", "execution_count": 28, "id": "c13a072b-ca2b-47bb-9ed1-2a88885aad93", "metadata": {}, "outputs": [], "source": [ "def default_transform(structure: Structure) -> NDArray[np.double]:\n", " return structure.coords" ] }, { "cell_type": "code", "execution_count": 29, "id": "84f1a37c-4ac5-4086-bfc8-260bc8f2de8d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([[-12.6210985, -9.128864 , 17.258345 ],\n", " [-13.660538 , -8.840154 , 16.265753 ],\n", " [-13.38884 , -7.5286074, 15.524057 ],\n", " ...,\n", " [ 7.1754823, -19.776093 , 21.191608 ],\n", " [ 9.592421 , -19.601936 , 19.197937 ],\n", " [ 9.432583 , -17.757698 , 20.458267 ]], dtype=float32),\n", " array([[-13.215324 , -11.076905 , 15.214827 ],\n", " [-14.133494 , -10.301853 , 14.386162 ],\n", " [-13.614427 , -8.882867 , 14.150835 ],\n", " ...,\n", " [ 6.8049664, -15.96424 , 20.159506 ],\n", " [ 9.545281 , -16.992254 , 18.400843 ],\n", " [ 7.0213795, -15.329907 , 18.152586 ]], dtype=float32))" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_data = CustomPinderDataset(split=\"test\", transform=default_transform, target_transform=default_transform)\n", "test_data[0]" ] }, { "cell_type": "markdown", "id": "f24f2498-da37-4518-b1c8-2c06652767ac", "metadata": {}, "source": [ "### Implementing diversity sampling\n", "\n", "Here, we provide an example of how one might use `torch.utils.data.WeightedRandomSampler`. However, users are free to sample diversity any way they see fit. For this example, we are going to sample diversity inversely proportional to pinder cluster population size. \n" ] }, { "cell_type": "code", "execution_count": 30, "id": "8c495bfe-3aa8-4544-b94e-3bcb6e149dba", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torch.utils.data import WeightedRandomSampler\n", "\n", "def inverse_cluster_size_sampler(dataset: PinderDataset, replacement: bool = True):\n", " index = dataset.index\n", " cluster_counts = (\n", " index[\"cluster_id\"].value_counts().rename(\"cluster_count\")\n", " )\n", " index = index.merge(\n", " cluster_counts, left_on=\"cluster_id\", right_index=True\n", " )\n", " # undersample large clusters\n", " cluster_weights = 1.0 / torch.tensor(index.cluster_count.values)\n", " return WeightedRandomSampler(\n", " weights=cluster_weights,\n", " num_samples=len(\n", " cluster_counts\n", " ),\n", " replacement=replacement,\n", " )\n", "\n", "sampler = inverse_cluster_size_sampler(\n", " test_data,\n", " replacement=True,\n", ")\n", "sampler\n" ] }, { "cell_type": "markdown", "id": "6fffb0cb-53b7-452a-b724-65935d8e1f9f", "metadata": {}, "source": [ "### Defining the dataloader \n", "\n", "Now that we have implemented a dataset and sampling function, we can tie everything together to implement the `DataLoader`.\n" ] }, { "cell_type": "code", "execution_count": 31, "id": "cc7f7dc6-f01b-4d52-a4b5-e0dd19d68fe4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[[-2.5430e+00, 1.7731e+01, -2.8000e-02],\n", " [-2.4430e+00, 1.6607e+01, 9.4600e-01],\n", " [-2.3810e+00, 1.7162e+01, 2.3870e+00],\n", " ...,\n", " [-3.8556e+01, -1.1730e+00, -1.2576e+01],\n", " [-3.8980e+01, -7.6000e-02, -1.2154e+01],\n", " [-3.9255e+01, -1.9450e+00, -1.3256e+01]]]),\n", " tensor([[[ 12.3608, -5.7149, 7.7741],\n", " [ 12.8355, -5.7361, 9.1869],\n", " [ 12.1397, -6.8793, 
{ "cell_type": "markdown", "id": "f24f2498-da37-4518-b1c8-2c06652767ac", "metadata": {}, "source": [ "### Implementing diversity sampling\n", "\n", "Here, we provide an example of how one might use `torch.utils.data.WeightedRandomSampler`; users are of course free to implement diversity sampling in any way they see fit. In this example, we weight each system inversely proportional to the size of its pinder cluster, so that systems from heavily populated clusters are undersampled.\n" ] },
{ "cell_type": "code", "execution_count": 30, "id": "8c495bfe-3aa8-4544-b94e-3bcb6e149dba", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<torch.utils.data.sampler.WeightedRandomSampler object at 0x...>" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torch.utils.data import WeightedRandomSampler\n", "\n", "def inverse_cluster_size_sampler(dataset: PinderDataset, replacement: bool = True) -> WeightedRandomSampler:\n", "    index = dataset.index\n", "    cluster_counts = (\n", "        index[\"cluster_id\"].value_counts().rename(\"cluster_count\")\n", "    )\n", "    index = index.merge(\n", "        cluster_counts, left_on=\"cluster_id\", right_index=True\n", "    )\n", "    # One weight per system: undersample large clusters\n", "    cluster_weights = 1.0 / torch.tensor(index.cluster_count.values)\n", "    return WeightedRandomSampler(\n", "        weights=cluster_weights,\n", "        # With replacement=True, each epoch draws roughly one system per\n", "        # cluster; use len(index) here to keep the full dataset size instead\n", "        num_samples=len(cluster_counts),\n", "        replacement=replacement,\n", "    )\n", "\n", "sampler = inverse_cluster_size_sampler(\n", "    test_data,\n", "    replacement=True,\n", ")\n", "sampler\n" ] },
{ "cell_type": "markdown", "id": "6fffb0cb-53b7-452a-b724-65935d8e1f9f", "metadata": {}, "source": [ "### Defining the dataloader\n", "\n", "Now that we have implemented a dataset and a sampling function, we can tie everything together into a `DataLoader`.\n" ] },
{ "cell_type": "code", "execution_count": 31, "id": "cc7f7dc6-f01b-4d52-a4b5-e0dd19d68fe4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[[-2.5430e+00, 1.7731e+01, -2.8000e-02],\n", " [-2.4430e+00, 1.6607e+01, 9.4600e-01],\n", " [-2.3810e+00, 1.7162e+01, 2.3870e+00],\n", " ...,\n", " [-3.8556e+01, -1.1730e+00, -1.2576e+01],\n", " [-3.8980e+01, -7.6000e-02, -1.2154e+01],\n", " [-3.9255e+01, -1.9450e+00, -1.3256e+01]]]),\n", " tensor([[[ 12.3608, -5.7149, 7.7741],\n", " [ 12.8355, -5.7361, 9.1869],\n", " [ 12.1397, -6.8793, 9.9598],\n", " ...,\n", " [ 29.1573, -15.5614, 2.8228],\n", " [ 29.4802, -14.5300, 2.1956],\n", " [ 27.9742, -15.8473, 3.0789]]]))" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torch.utils.data import DataLoader\n", "\n", "test_dataloader = DataLoader(\n", "    test_data,\n", "    batch_size=1,\n", "    # shuffle is mutually exclusive with sampler\n", "    shuffle=False,\n", "    sampler=sampler,\n", ")\n", "test_features, test_labels = next(iter(test_dataloader))\n", "test_features, test_labels\n" ] },
{ "cell_type": "markdown", "id": "3436d5a6-6020-4316-a0c0-af10f8831b7b", "metadata": {}, "source": [ "Putting it all together, we can now construct a train/val/test dataloader as follows:\n" ] },
{ "cell_type": "code", "execution_count": 32, "id": "fe172bdd-09eb-45ed-bc51-28202e5158b6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<torch.utils.data.dataloader.DataLoader object at 0x...>" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from typing import Any, Callable\n", "\n", "\n", "def get_loader(\n", "    dataset: CustomPinderDataset,\n", "    # A Sampler instance (e.g. from inverse_cluster_size_sampler);\n", "    # mutually exclusive with shuffle=True\n", "    sampler: torch.utils.data.Sampler | None = None,\n", "    batch_size: int = 2,\n", "    shuffle: bool = False,\n", "    num_workers: int = 0,\n", "    collate_fn: Callable[[list[tuple[NDArray[np.double], NDArray[np.double]]]], tuple[torch.Tensor, torch.Tensor]] | None = None,\n", "    **kwargs: Any,\n", ") -> \"DataLoader[CustomPinderDataset]\":\n", "    return DataLoader(\n", "        dataset,\n", "        batch_size=batch_size,\n", "        shuffle=shuffle,\n", "        num_workers=num_workers,\n", "        sampler=sampler,\n", "        collate_fn=collate_fn,\n", "        **kwargs,\n", "    )\n", "\n", "\n", "train_data = CustomPinderDataset(\n", "    split=\"train\",\n", "    structure_filters=[filters.MinAtomTypesFilter()],\n", "    metadata_filter=\"(buried_sasa >= 500)\",\n", "    transform=default_transform,\n", "    target_transform=default_transform,\n", ")\n", "train_dataloader = get_loader(\n", "    train_data,\n", "    sampler=inverse_cluster_size_sampler(train_data),\n", "    batch_size=1,\n", ")\n", "train_dataloader\n" ] },
{ "cell_type": "code", "execution_count": 33, "id": "b7602385-2893-4c09-afc3-6a10433c7ada", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-09-05 14:30:01,983 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=7, items=7\n", "2024-09-05 14:30:03,261 | pinder.core.utils.cloud.process_many:23 | INFO : runtime succeeded: 1.28s\n" ] }, { "data": { "text/plain": [ "(tensor([[[125.8500, 300.1720, 233.1940],\n", " [124.5070, 299.9340, 232.7260],\n", " [124.1540, 301.1870, 231.9160],\n", " ...,\n", " [113.6850, 300.0020, 243.7090],\n", " [114.5000, 301.0400, 243.7710],\n", " [119.5730, 297.7760, 244.9850]]]),\n", " tensor([[[125.8500, 300.1720, 233.1940],\n", " [124.5070, 299.9340, 232.7260],\n", " [124.1540, 301.1870, 231.9160],\n", " ...,\n", " [113.6850, 300.0020, 243.7090],\n", " [114.5000, 301.0400, 243.7710],\n", " [119.5730, 297.7760, 244.9850]]]))" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_features, train_labels = next(iter(train_dataloader))\n", "train_features, train_labels\n" ] },
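{ "cell_type": "markdown", "id": "5c6d7e8f-9a0b-4c1d-8e2f-3a4b5c6d7e8f", "metadata": {}, "source": [ "At this point we have everything needed for a basic training loop. The sketch below is hypothetical: `ToyModel` and its coordinate-denoising objective are stand-ins for your own architecture and loss, shown only to demonstrate how batches flow from the dataloader through a model and optimizer:\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "2b3c4d5e-6f7a-4b8c-9d0e-1f2a3b4c5d6e", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "class ToyModel(nn.Module):\n", "    # Hypothetical stand-in for a real architecture: predicts a\n", "    # per-atom refinement of the input coordinates\n", "    def __init__(self) -> None:\n", "        super().__init__()\n", "        self.mlp = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 3))\n", "\n", "    def forward(self, coords: torch.Tensor) -> torch.Tensor:\n", "        return coords + self.mlp(coords)\n", "\n", "model = ToyModel()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n", "\n", "for features, _labels in train_dataloader:\n", "    optimizer.zero_grad()\n", "    # Purely illustrative objective: denoise jittered input coordinates\n", "    noisy = features + 0.1 * torch.randn_like(features)\n", "    loss = nn.functional.mse_loss(model(noisy), features)\n", "    loss.backward()\n", "    optimizer.step()\n", "    break  # single step shown for brevity\n" ] },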
{ "cell_type": "markdown", "id": "323f286e-b81c-4863-9902-e9bc5f4b9b28", "metadata": {}, "source": [ "### Using a larger batch size (collate_fn)\n", "\n", "What if we want to use a larger batch size?\n", "\n", "By default, the `collate_fn` used by the `DataLoader` is `torch.utils.data._utils.collate.default_collate`, which expects to be able to stack tensors via `torch.stack`.\n", "\n", "Since different pinder systems have differing numbers of atoms (and therefore coordinates), using a batch size greater than 1 causes this function to raise a RuntimeError with a message like:\n", "`RuntimeError: stack expects each tensor to be equal size, but got [688, 3] at entry 0 and [1391, 3] at entry 1`\n", "\n", "We can leverage existing pinder utilities to adapt the default collate_fn so that tensor dimensions are padded with dummy values, allowing them to be stacked.\n" ] },
{ "cell_type": "code", "execution_count": 34, "id": "cce6f303-99e7-4d42-805d-72f39d9c61b4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Help on function pad_and_stack in module pinder.core.loader.dataset:\n", "\n", "pad_and_stack(tensors: 'list[Tensor]', dim: 'int' = 0, dims_to_pad: 'list[int] | None' = None, value: 'int | float | None' = None) -> 'Tensor'\n", "    Pads a list of tensors to the maximum length observed along each dimension and then stacks them along a new dimension (given by `dim`).\n", "    \n", "    Parameters:\n", "        tensors (list[Tensor]): A list of tensors to pad and stack\n", "        dim (int): The new dimension to stack along.\n", "        dims_to_pad (list[int] | None): The dimensions to pad\n", "        value (int | float | None, optional): The value to pad with, by default None\n", "    \n", "    Returns:\n", "        Tensor: The padded and stacked tensor. Below are examples of input and output shapes\n", "        Example 1: Sequence features (although redundant with torch.rnn.utils.pad_sequence)\n", "            input: [(2,), (7,)], dim: 0\n", "            output: (2, 7)\n", "        Example 2: Pair features (e.g., pairwise coordinates)\n", "            input: [(4, 4, 3), (7, 7, 3)], dim: 0\n", "            output: (2, 7, 7, 3)\n", "\n" ] } ], "source": [ "from pinder.core.loader.dataset import pad_and_stack\n", "\n", "help(pad_and_stack)\n" ] },
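{ "cell_type": "markdown", "id": "8e9f0a1b-2c3d-4e5f-a6b7-c8d9e0f1a2b3", "metadata": {}, "source": [ "Before wiring `pad_and_stack` into a collate function, it can help to see it on toy tensors. The example below follows the shapes documented in the docstring above; the tensors `a` and `b` and the pad value of `-1` are arbitrary illustrative choices:\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "4d5e6f7a-8b9c-4d0e-af12-b3c4d5e6f7a8", "metadata": {}, "outputs": [], "source": [ "a = torch.zeros(5, 3)\n", "b = torch.ones(8, 3)\n", "# Both tensors are padded to 8 rows, then stacked along a new leading\n", "# dimension, so the result has shape (2, 8, 3)\n", "stacked = pad_and_stack([a, b], dim=0, value=-1)\n", "stacked.shape\n" ] },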
{ "cell_type": "code", "execution_count": 35, "id": "4fbf640e-3348-439c-a8d0-ae4bb2c86dc3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-09-05 14:30:03,504 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=7, items=7\n", "2024-09-05 14:30:04,679 | pinder.core.utils.cloud.process_many:23 | INFO : runtime succeeded: 1.17s\n", "2024-09-05 14:30:07,053 | pinder.core.utils.cloud:375 | INFO : Gsutil process_many=download_to_filename, threads=6, items=6\n", "2024-09-05 14:30:07,242 | pinder.core.utils.cloud.process_many:23 | INFO : runtime succeeded: 0.19s\n" ] }, { "data": { "text/plain": [ "(tensor([[[ 173.7118, 222.4612, 146.3869],\n", " [ 173.7950, 221.7822, 147.6740],\n", " [ 175.1321, 221.0380, 147.7151],\n", " ...,\n", " [ 162.2876, 298.3475, 158.2268],\n", " [ 163.3745, 297.5005, 157.0505],\n", " [ 164.2955, 296.3959, 158.1551]],\n", " \n", " [[ 274.2440, 238.6180, 254.4800],\n", " [ 275.1950, 238.5360, 253.3790],\n", " [ 274.5000, 238.5590, 252.0220],\n", " ...,\n", " [-100.0000, -100.0000, -100.0000],\n", " [-100.0000, -100.0000, -100.0000],\n", " [-100.0000, -100.0000, -100.0000]]]),\n", " tensor([[[ 179.8020, 204.6620, 163.5760],\n", " [ 179.3740, 203.2420, 163.3880],\n", " [ 180.6090, 202.3680, 163.1150],\n", " ...,\n", " [ 185.2960, 297.5130, 155.7020],\n", " [ 183.4960, 297.2930, 155.6730],\n", " [ 182.9220, 298.8890, 156.2510]],\n", " \n", " [[ 274.2440, 238.6180, 254.4800],\n", " [ 275.1950, 238.5360, 253.3790],\n", " [ 274.5000, 238.5590, 252.0220],\n", " ...,\n", " [-100.0000, -100.0000, -100.0000],\n", " [-100.0000, -100.0000, -100.0000],\n", " [-100.0000, -100.0000, -100.0000]]]))" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def collate_coordinates(\n", "    batch: list[tuple[NDArray[np.double] | torch.Tensor, NDArray[np.double] | torch.Tensor]],\n", "    coords_pad_value: int = -100,\n", ") -> tuple[torch.Tensor, torch.Tensor]:\n", "    feature_coords = []\n", "    target_coords = []\n", "    for feat, target in batch:\n", "        if isinstance(feat, np.ndarray):\n", "            feat = torch.tensor(feat, dtype=torch.float32)\n", "        if isinstance(target, np.ndarray):\n", "            target = torch.tensor(target, dtype=torch.float32)\n", "        feature_coords.append(feat)\n", "        target_coords.append(target)\n", "\n", "    # Pad the ragged [N_atoms, 3] tensors with the pad value, then stack\n", "    # them into a single [batch_size, N_max, 3] tensor\n", "    feature_coords = pad_and_stack(feature_coords, dim=0, value=coords_pad_value)\n", "    target_coords = pad_and_stack(target_coords, dim=0, value=coords_pad_value)\n", "    return feature_coords, target_coords\n", "\n", "\n", "train_dataloader = get_loader(\n", "    train_data,\n", "    sampler=inverse_cluster_size_sampler(train_data),\n", "    collate_fn=collate_coordinates,\n", "    batch_size=2,\n", ")\n", "train_features, train_labels = next(iter(train_dataloader))\n", "train_features, train_labels\n" ] } ], "metadata": { "kernelspec": { "display_name": "pinder", "language": "python", "name": "pinder" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }