ProteinGraphDataset for pairs of proteins? #224

Open
kamurani opened this issue Oct 30, 2022 · 3 comments
Labels: 1 - Priority P1 (High Priority), enhancement (New feature or request), ML

Comments

@kamurani
Contributor

I was wondering if there's an elegant (easy) way to use the built-in ProteinGraphDataset class to build a PyTorch DataLoader that supplies proteins in pairs, each pair with an associated label.

Many GNN use cases involve predicting some kind of interaction or behaviour that involves two or more proteins (e.g. interaction or binding affinity, as a binary classification label). At the moment I can only see a way to supply one graph label per protein ID (a 1:1 mapping).

Has anyone used the Graphein dataset classes for this use case? If not, I would be happy to try modifying the dataloader class to allow this as an option, although I would greatly appreciate some pointers on how best to do this.

Thanks everyone! And sorry if there's an easy way to do this that I've simply missed in the docs.

Cheers

@a-r-j
Owner

a-r-j commented Oct 30, 2022

Hey @kamurani, great idea!

This would be a great feature - is this something you'd like to work on?

If you're just looking to get up and running, I threw together this (untested) solution, assuming you're happy to handle the conversion to Data objects and the pairing yourself.

```python
import torch
from torch_geometric.data import Data, InMemoryDataset
from typing import Any, Callable, List, Optional, Tuple


def pair_data(a: Data, b: Data) -> Data:
    """Pairs two graphs together in a single ``Data`` instance.

    The first graph is accessed via ``data.a`` (e.g. ``data.a.coords``) and the second via ``data.b``.

    :param a: The first graph.
    :type a: torch_geometric.data.Data
    :param b: The second graph.
    :type b: torch_geometric.data.Data
    :return: The paired graph.
    :rtype: torch_geometric.data.Data
    """
    out = Data()
    out.a = a
    out.b = b
    return out


class PairedProteinGraphListDataset(InMemoryDataset):
    def __init__(
        self,
        root: str,
        data_list: List[Tuple[Data, Data]],
        name: str,
        labels: Optional[Any] = None,
        transform: Optional[Callable] = None,
    ):
        """Creates a dataset of graph pairs from a list of PyTorch Geometric Data tuples.

        :param root: Root directory where the dataset is stored.
        :type root: str
        :param data_list: List of ``(a, b)`` protein graph pairs as PyTorch
            Geometric Data objects.
        :type data_list: List[Tuple[Data, Data]]
        :param name: Name of dataset. Data will be saved as ``data_{name}.pt``.
        :type name: str
        :param labels: Optional labels, one per pair and in the same order as
            ``data_list``. Stored as ``data.y``. (default: :obj:`None`)
        :type labels: Optional[Any], optional
        :param transform: A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        :type transform: Optional[Callable], optional
        """
        # Set these before super().__init__(), which calls process() when the
        # processed file does not exist yet.
        self.data_list = data_list
        self.name = name
        self.labels = labels
        super().__init__(root, transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """The name of the files in the :obj:`self.processed_dir` folder that
        must be present in order to skip processing."""
        return f"data_{self.name}.pt"

    def process(self):
        """Saves data files to disk."""
        # Pair data objects
        paired_data = [pair_data(a, b) for a, b in self.data_list]

        # Assign labels (one per pair)
        if self.labels is not None:
            for i, d in enumerate(paired_data):
                d.y = self.labels[i]

        torch.save(self.collate(paired_data), self.processed_paths[0])
```
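Storing the graphs as nested ``a``/``b`` attributes keeps each graph intact, but a GNN usually wants each side of a mini-batch merged into one disjoint graph. PyG's "Advanced Mini-Batching" guide covers this case with a ``PairData`` subclass that overrides ``__inc__`` so that ``edge_index_s`` and ``edge_index_t`` are each offset by their own node counts (used together with ``follow_batch`` on the ``DataLoader``). Below is a dependency-free sketch of that offsetting. Plain lists stand in for tensors, and ``collate_pairs`` plus the ``(edge_index, num_nodes)`` tuple format are hypothetical, not part of Graphein or PyG:

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for a PyG graph: (edge_index as (src, dst) pairs, num_nodes)
Graph = Tuple[List[Tuple[int, int]], int]


def collate_pairs(pairs: List[Tuple[Graph, Graph, float]]) -> Dict[str, list]:
    """Batch (graph_a, graph_b, label) pairs into two disjoint-union graphs.

    Each side is batched independently: edges of the i-th ``a`` graph are
    shifted by the total node count of the ``a`` graphs before it, and
    likewise for ``b``. This mirrors what a ``PairData.__inc__`` override does.
    """
    edges_a: List[Tuple[int, int]] = []
    edges_b: List[Tuple[int, int]] = []
    batch_a: List[int] = []  # pair index for every node on side a
    batch_b: List[int] = []  # pair index for every node on side b
    labels: List[float] = []
    offset_a = offset_b = 0
    for i, ((ei_a, n_a), (ei_b, n_b), y) in enumerate(pairs):
        edges_a.extend((s + offset_a, t + offset_a) for s, t in ei_a)
        edges_b.extend((s + offset_b, t + offset_b) for s, t in ei_b)
        batch_a.extend([i] * n_a)
        batch_b.extend([i] * n_b)
        labels.append(y)
        offset_a += n_a
        offset_b += n_b
    return {
        "edge_index_a": edges_a,
        "batch_a": batch_a,
        "edge_index_b": edges_b,
        "batch_b": batch_b,
        "y": labels,
    }


# Two toy pairs: a 3-node/2-node pair labelled 1.0 and a 2-node/4-node pair labelled 0.0
pairs = [
    (([(0, 1), (1, 2)], 3), ([(0, 1)], 2), 1.0),
    (([(0, 1)], 2), ([(0, 1), (1, 2), (2, 3)], 4), 0.0),
]
batch = collate_pairs(pairs)
print(batch["edge_index_a"])  # [(0, 1), (1, 2), (3, 4)]
print(batch["batch_b"])       # [0, 0, 1, 1, 1, 1]
```

The ``batch_a``/``batch_b`` vectors are what pooling layers need to pool each side per pair before combining the two embeddings for the pair label.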

@kamurani
Contributor Author

Legend, thank you! I would be happy to implement this for Graphein in a similar way, as an extension to the other Dataset classes. My use case also involves a particular node of interest (a specific amino acid residue in the protein) being specified for one or both of the graphs, which might be useful for other people too.

For example, a single Data object would contain the protein1 graph, a centre_node string, the protein2 graph, and the label.

For each training example, this residue of interest could also be stored in the Data object and used for downstream processing (in my case, selecting a subgraph of protein1 using the coordinates of that residue of interest, although this will happen during pre-processing, before the g: nx.Graph is converted to a PyTorch object).
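A minimal sketch of that pre-processing step, assuming each node has 3D coordinates and using a simple distance cutoff around the residue of interest. ``residues_within_radius``, the cutoff rule, and the toy coordinates are hypothetical illustrations, not a Graphein API:

```python
import math
from typing import Dict, List, Sequence


def residues_within_radius(
    coords: Dict[str, Sequence[float]], centre_node: str, radius: float
) -> List[str]:
    """Return node IDs whose coordinates lie within ``radius`` of ``centre_node``.

    ``coords`` maps node ID -> (x, y, z). The centre node itself is always
    included (distance 0). Hypothetical helper for illustration only.
    """
    centre = coords[centre_node]
    return [node for node, xyz in coords.items() if math.dist(centre, xyz) <= radius]


# Toy node IDs in an assumed "chain:residue:number" style; coordinates are made up.
coords = {
    "A:SER:10": (0.0, 0.0, 0.0),
    "A:GLY:11": (3.8, 0.0, 0.0),
    "A:LYS:40": (0.0, 12.0, 0.0),
}
print(residues_within_radius(coords, "A:SER:10", radius=5.0))
# ['A:SER:10', 'A:GLY:11']
```

With a networkx graph, the mapping could be built as {n: d["coords"] for n, d in g.nodes(data=True)} (assuming a "coords" node attribute) and the returned IDs passed to g.subgraph(...) before conversion to PyTorch.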

@a-r-j
Owner

a-r-j commented Oct 31, 2022

I see. It's probably best to set it up to accept arbitrary additional data (like centre_node) instead of hardcoding a specific use case. Sounds like an interesting application!
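One way to keep it generic, sketched here without any PyG dependency: a container holding the two graphs, an optional label, and a free-form extras dict, so that centre_node is just one possible extra field. PairedExample and make_pair are hypothetical names, not existing Graphein classes:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, Optional, TypeVar

G = TypeVar("G")  # graph type, e.g. nx.Graph before conversion or a PyG Data after


@dataclass
class PairedExample(Generic[G]):
    """One training example: two graphs, an optional label, and free-form extras.

    ``extras`` carries use-case-specific fields (e.g. ``centre_node``) so the
    container itself does not hardcode any particular application.
    """
    a: G
    b: G
    y: Optional[Any] = None
    extras: Dict[str, Any] = field(default_factory=dict)


def make_pair(a: G, b: G, y: Optional[Any] = None, **extras: Any) -> PairedExample[G]:
    """Build a PairedExample; every extra keyword argument lands in ``extras``."""
    return PairedExample(a=a, b=b, y=y, extras=dict(extras))


ex = make_pair("protein1_graph", "protein2_graph", y=1, centre_node="A:SER:10")
print(ex.extras)  # {'centre_node': 'A:SER:10'}
```

When converting to PyG, the same idea could map onto pair_data by copying each extra onto the output with setattr(out, key, value), since a Data object accepts arbitrary attributes (as pair_data already does with out.a = a).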

Let me know if you want any help on the PR :)

a-r-j added the enhancement (New feature or request), 1 - Priority P1 (High Priority), and ML labels on Nov 2, 2022