
[FEATURE]: Add option to make component order in multicomponent not matter #806

Open
KnathanM opened this issue Apr 17, 2024 · 13 comments
Labels: enhancement (a new feature request), mlpds (Issue/PR by MLPDS member, priority)
Milestone: v2.1.0

Comments

@KnathanM
Contributor

At the MLPDS meeting someone brought up that in multicomponent the order of components currently matters because the learned representations are concatenated. Could we add an option to make the architecture order invariant? The only way I can think of is summing/averaging the learned representations. I started some work on this for my solvent mixtures project.

In any event, there aren't many cases of multicomponent datasets where the order of components doesn't matter. Usually it is solute + solvent, or rxn + solvent. Posting an issue in case others have similar ideas. This issue shouldn't get a milestone and can be closed if there is no discussion on it.

KnathanM added the enhancement label on Apr 17, 2024
kevingreenman added the mlpds label on Apr 17, 2024
@kevingreenman
Member

We could have options for both invariance and equivariance, to cover cases where the user wants the prediction to be the same regardless of order, as well as the cases where they want the prediction to be the negative if the order is switched.

@davidegraff
Contributor

You would probably want to achieve this using some module:

class Fusion(nn.Module):
    ...
    def forward(self, *Xs: Tensor) -> Tensor:
        """Fuse the input tensors into a single output tensor"""

where you would have different implementations for the operators: $\oplus, \odot, ||$, etc. This module should be placed inside a MulticomponentMPNN and used in place of the call to torch.cat.
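For concreteness, a couple of implementations of that interface might look like the following (a rough sketch building on the Fusion stub above; the class names are hypothetical, not existing Chemprop API):

import torch
from torch import Tensor


class ConcatFusion(Fusion):
    """Order-dependent: reproduces the current torch.cat behavior."""

    def forward(self, *Xs: Tensor) -> Tensor:
        return torch.cat(Xs, dim=1)


class SumFusion(Fusion):
    """Order-invariant: summation commutes, so permuting the components changes nothing."""

    def forward(self, *Xs: Tensor) -> Tensor:
        return torch.stack(Xs, dim=0).sum(dim=0)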

re: equivariance- @kevingreenman, where are you getting the negative from, i.e., why would permuting the order of the inputs negate the original output? Given some embedding matrix of multiple molecules $\mathbf X \in \mathbb R^{n \times (m \cdot d)}$, where $d$ is the latent dimension of a single molecule and $m$ is the number of molecules in each input; and an $md \times md$ permutation matrix $\mathbf \Pi$, an equivariant function $\mathbf W$ only guarantees $\mathbf X \mathbf W \mathbf \Pi \equiv \mathbf X \mathbf \Pi \mathbf W$. The permutation action $\mathbf \Pi$ doesn't necessarily negate the original input $\mathbf X$, so why would it negate the downstream output?
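As a quick numerical illustration of that point (a throwaway sketch, not Chemprop code): a block permutation of a concatenated two-component embedding swaps the components but does not negate the input.

import torch

n, d = 4, 3
X1, X2 = torch.randn(n, d), torch.randn(n, d)
X = torch.cat([X1, X2], dim=1)  # n x (2d)

# permutation matrix that swaps the two d-sized blocks
Pi = torch.zeros(2 * d, 2 * d)
Pi[:d, d:] = torch.eye(d)
Pi[d:, :d] = torch.eye(d)

assert torch.allclose(X @ Pi, torch.cat([X2, X1], dim=1))  # the components are swapped...
assert not torch.allclose(X @ Pi, -X)                      # ...but the input is not negated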

@kevingreenman
Member

> re: equivariance- @kevingreenman, where are you getting the negative from, i.e., why would permuting the order of the inputs negate the original output? Given some embedding matrix of multiple molecules $\mathbf X \in \mathbb R^{n \times (m \cdot d)}$, where $d$ is the latent dimension of a single molecule and $m$ is the number of molecules in each input; and an $md \times md$ permutation matrix $\mathbf \Pi$, an equivariant function $\mathbf W$ only guarantees $\mathbf X \mathbf W \mathbf \Pi \equiv \mathbf X \mathbf \Pi \mathbf W$. The permutation action $\mathbf \Pi$ doesn't necessarily negate the original input $\mathbf X$, so why would it negate the downstream output?

@davidegraff good point. I'm referring to the specific case of $m = 2$, so instead of a general permutation, we'd be trying to guarantee that $f(-X) = -f(X)$. I guess antisymmetric would be a more appropriate description than equivariant.

@davidegraff
Contributor

davidegraff commented Apr 17, 2024

But in the case of permuting an input of $m=2$ components, where $\mathbf X = [ \mathbf X_1, \mathbf X_2 ]$ and $\mathbf X \mathbf \Pi \coloneqq [ \mathbf X_2, \mathbf X_1 ]$, I'm still missing where your desire/expectation of antisymmetry is coming from. If I understand correctly, the desire is that $f(\mathbf X \mathbf \Pi) = -f(\mathbf X)$, but I'm not familiar with architectures that can accomplish this. Though I might be missing something here.

@kevingreenman
Member

kevingreenman commented Apr 18, 2024

Could taking the difference or cross product between the embeddings (instead of a sum or average in the case of invariance) achieve the antisymmetry?

For context, I'm trying to think of whether there's a way for us to satisfy some of the "guarantees" from approaches like DeepDelta (which is just a wrapper that calls chemprop with --number_of_molecules 2) that currently require significant data augmentation. I think predicting property differences is the most obvious example where this type of antisymmetry would be useful, but maybe there are others too.
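For anyone unfamiliar with that setup, the augmentation is roughly the following (a sketch of the general delta-learning recipe as I understand it, not DeepDelta's actual code; the column names are made up):

import itertools
import pandas as pd

df = pd.DataFrame({"smiles": ["CCO", "CCN", "CCC"], "y": [1.2, 0.7, -0.3]})

# build all ordered pairs and train a 2-component model on the property difference;
# including both (i, j) and (j, i) is what approximates the antisymmetry
pairs = [
    {"smiles_1": df.smiles[i], "smiles_2": df.smiles[j], "delta": df.y[i] - df.y[j]}
    for i, j in itertools.permutations(range(len(df)), 2)
]
paired = pd.DataFrame(pairs)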

@kevingreenman
Member

> Could taking the difference or cross product between the embeddings (instead of a sum or average in the case of invariance) achieve the antisymmetry?

Thinking about it more, I realize that doesn't make sense. Those things would ensure the encoding is antisymmetric, but that would give no guarantees about the output after the encoding goes through the feedforward network.
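A quick check of that point (throwaway sketch): even if the fused encoding flips sign when the components are swapped, a generic feedforward network is not an odd function, so the output won't.

import torch
from torch import nn

torch.manual_seed(0)
ffn = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

z = torch.randn(5, 8)                    # stand-in for an antisymmetric encoding, e.g. X1 - X2
print(torch.allclose(ffn(-z), -ffn(z)))  # False: the biases and ReLU break antisymmetry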

@davidegraff
Contributor

Yeah exactly, there’s nothing in the DeepDelta architecture that guarantees antisymmetry. They’re just relying on typical data augmentation to approximate an antisymmetric function. FWIW, I don’t see why we would want this in Chemprop. The measured property should be invariant to the order of our components (i.e., symmetric), as opposed to pairwise property differences, which should be antisymmetric.

@cjmcgill
Contributor

Not a straightforward problem to solve; I've been thinking about it a lot lately. If you sum the fingerprints before the FFN you lose some resolution on the data: what if there's a benefit to seeing how dissimilar two molecules are? Now you don't really know. I talked with @cbilodeau2 about how to deal with this a while ago. The solution she landed on is in the linked paper below: averaging the FFN output over both orderings. It's not easily extensible to multicomponent systems, but if you are willing to limit to 2 components I think it's a very good approach.
https://www.sciencedirect.com/science/article/pii/S1385894723011853
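If I'm reading the description right, the 2-component version amounts to something like the following (my paraphrase of the idea, not the paper's code; the class name is made up):

import torch
from torch import Tensor, nn


class SymmetrizedReadout(nn.Module):
    """Average the FFN output over both orderings of a 2-component input."""

    def __init__(self, ffn: nn.Module) -> None:
        super().__init__()
        self.ffn = ffn

    def forward(self, X1: Tensor, X2: Tensor) -> Tensor:
        y_12 = self.ffn(torch.cat([X1, X2], dim=1))
        y_21 = self.ffn(torch.cat([X2, X1], dim=1))
        return (y_12 + y_21) / 2  # invariant to swapping the two components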

@davidegraff
Contributor

There are a variety of equivariant pooling techniques to choose from (notably, not $\mathtt{cat}$), and deciding which to use is an empirical question. If I'm reading the above correctly, it's a specific implementation of Janossy pooling,1 for $m=2$, but you can extend it to arbitrary $m$ and make it equivariant by averaging over all $m!$ permutations. Though that's not to say you actually need true equivariance for equivariant problems: an approximation is sometimes good enough. Buterez et al.2 looked at Janossy pooling in the context of node aggregation during message passing and found that neural aggregation functions are generally better, but there's no provable reason why, and it could just come down to the network's improved ability to overfit.

Footnotes

  1. https://arxiv.org/abs/1811.01900

  2. https://arxiv.org/pdf/2211.04952.pdf

@cjmcgill
Contributor

Thanks for the reference. This does seem to be an instance of Janossy pooling; glad to have a formal name for it now. Yes, averaging over more permutations is a fine way to extend it, but I'd be a bit stumped how to code that practically into Chemprop other than as a for loop with hardcoded options for 2, 3, or 4 components.

@davidegraff
Contributor

davidegraff commented Apr 19, 2024

re: Python implementation

from itertools import permutations
from random import sample

import torch
from torch import Tensor, nn


class Fusion(nn.Module):
    def forward(self, *Xs: Tensor) -> Tensor:
        """
        Parameters
        -----------
        *Xs: tuple[Tensor, ...]
            a tuple of tensors of shape `n x d` containing the aggregated feature representations,
            where `n` is the batch size and `d` is the embedding dimensionality

        Returns
        -------
        Tensor
            A tensor of shape `n x *` containing the aggregated combined feature representations
        """

class JanossyFusion(Fusion):
    def __init__(self, mlp: nn.Module, k: int) -> None:
        super().__init__()  # register `mlp` as a submodule
        self.mlp = mlp
        self.k = k

    def forward(self, *Xs: Tensor) -> Tensor:
        Zs = []
        for pi in sample(list(permutations(len(Xs)), self.k)):
            pi = torch.randperm(len(Xs))
            X = torch.cat([Xs[i] for i in pi], dim=1)
            Z = self.mlp(X)
            Zs.append(Z)
        Z = torch.stack(Zs, 0).mean(0)

        return Z

kevingreenman added this to the v2.1.0 milestone on Apr 22, 2024
@shihchengli
Contributor

I used Janossy pooling for bond property prediction, since each bond has two directed bonds. The implementation generally looks good to me. However, I think the code in the for loop is incorrect: the input to the permutations function should be a sequence, and sampling from permutations and calling torch.randperm seem redundant with each other. If I understand the code correctly:

        for pi in sample(list(permutations(len(Xs)), self.k)):
            pi = torch.randperm(len(Xs))

can be modified as

        for pi in sample(list(permutations(range(len(Xs)))), self.k):

or

        for pi in range(self.k):
            pi = torch.randperm(len(Xs))

The second method cannot guarantee that the samples will not be repeated, so the first one looks better.

@davidegraff
Contributor

davidegraff commented Apr 24, 2024

There was an error in the original code block (the perils of writing code in markdown on your phone!). The for-loop should read:

    def forward(self, *Xs: Tensor) -> Tensor:
        Zs = []
        all_permutations = list(permutations(range(len(Xs))))
        for pi in sample(all_permutations, self.k):
            X = torch.cat([Xs[i] for i in pi], dim=1)
            Z = self.mlp(X)
            Zs.append(Z)
        Z = torch.stack(Zs, 0).mean(0)

        return Z

where pi is a permutation of the indices of the embeddings, i.e., a shuffling of the ordering of the embeddings. You can’t simply call torch.randperm repeatedly because each call returns only one of all possible permutations, with no guarantee that repeated calls give you distinct permutations.
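For completeness, a hypothetical usage sketch of how this could slot into a MulticomponentMPNN in place of the torch.cat call (sizes and names are illustrative, not the actual Chemprop API):

import torch
from torch import nn

# 3 components, each encoded to a 300-dim embedding
n_components, d, batch_size = 3, 300, 16
mlp = nn.Sequential(nn.Linear(n_components * d, 300), nn.ReLU(), nn.Linear(300, 300))

fusion = JanossyFusion(mlp, k=4)  # average over 4 sampled orderings (out of 3! = 6)
Xs = [torch.randn(batch_size, d) for _ in range(n_components)]
Z = fusion(*Xs)                   # batch_size x 300 fused representation for the FFN head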
