
Atomic encodings #861 (Open)

KnathanM wants to merge 1 commit into main
Conversation

@KnathanM (Contributor) commented May 9, 2024

Description

#722 asked about an interface for atomic encodings. This is simple to implement and may not be strictly necessary, since users can do it themselves, as @davidegraff showed in a comment. But in v2.2 we plan to add support for atom and bond targets, which would likely benefit from a model-level function for getting the atom encodings. Given how simple it is, we could add it now.

Example

atom_encodings = []
for batch in dataloader:
    bmg, V_d, *_ = batch
    atom_encodings.extend(model.atomic_encodings(bmg, V_d))

Each element of atom_encodings is a tensor of the atomic fingerprints for a single molecule.
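For example, for a 12-atom molecule with a message-passing hidden size of 300 (both numbers illustrative):

atom_encodings[0].shape  # torch.Size([12, 300]): one row per atom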

Questions

What is a good name for this function? I brainstormed `atom_encodings`, `atomic_encoding`, and `atoms_fingerprint`.

I still have a couple of things to do before merging this, if the reviews are positive.

Checklist

  • Add unit tests
  • Add example in notebooks

@@ -129,6 +129,10 @@ def encoding(
        """Calculate the :attr:`i`-th hidden representation"""
        return self.predictor.encode(self.fingerprint(bmg, V_d, X_d), i)

    def atomic_encodings(self, bmg: BatchMolGraph, V_d: Tensor | None = None) -> tuple[Tensor, ...]:
        """Return the atom-level encodings, one (n_atoms_i, d) tensor per molecule."""
        H_v = self.message_passing(bmg, V_d)
        return H_v.split(torch.bincount(bmg.batch).tolist())
Suggested change:

-        return H_v.split(torch.bincount(bmg.batch).tolist())
+        sizes = torch.bincount(bmg.batch, minlength=len(bmg)).tolist()
+        return H_v.split(sizes)
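As a minimal illustration of the edge case the minlength argument guards against (the numbers here are made up): if the trailing molecule(s) in a batch contribute zero atoms, bincount returns too few sizes and split() no longer yields one chunk per molecule.

import torch

batch = torch.tensor([0, 0, 1, 1, 1])        # atom-to-molecule index; suppose the batch holds 3 molecules
torch.bincount(batch).tolist()               # [2, 3]    -> split() would yield only 2 chunks
torch.bincount(batch, minlength=3).tolist()  # [2, 3, 0] -> one chunk per molecule, the last one empty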

    def atomic_encodings(
        self, bmgs: Iterable[BatchMolGraph], V_ds: Iterable[Tensor | None]
    ) -> list[tuple[Tensor, ...]]:
        H_vs: list[Tensor] = self.message_passing(bmgs, V_ds)
        return [H_v.split(torch.bincount(bmg.batch).tolist()) for H_v, bmg in zip(H_vs, bmgs)]
Same as above: the bincount call here needs minlength as well.
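Spelled out, the analogous suggested change here would be (a sketch building on the corrected comprehension above):

-        return [H_v.split(torch.bincount(bmg.batch).tolist()) for H_v, bmg in zip(H_vs, bmgs)]
+        return [H_v.split(torch.bincount(bmg.batch, minlength=len(bmg)).tolist()) for H_v, bmg in zip(H_vs, bmgs)]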

@davidegraff (Contributor) commented:
FWIW I prefer `atom_encodings`. I prefer to avoid the term "fingerprint" as much as possible, as it isn't a defined term in the CS community, but chemists occasionally like to use it, which is why it's been kept around. You'll also want to add a unit test or two based on the output shape. There are two strategies here:

  1. keep the code as-is and unit test that the function, given an input batch of $b$ molecules, produces a list of $b$ tensors in which the $i$-th tensor has shape $(n_{a,i}, d)$, where $n_{a,i}$ is the number of atoms in molecule $i$
  2. refactor this function to call some other utility function `split_into_tensors()` (or some other descriptive name) that handles the operation of taking a tensor of shape $(n, d)$ and a batch index $\mathbf i \in [0 .. b]^{n}$ and splitting it into a list of $b$ tensors of shape $(n_{a,i}, d)$, where "..."

The advantage of (2) is that it's simpler and more isolated: the unit test only needs to cover the `split_into_tensors()` function (testing `atom_encodings()` would now be an "integration" test). In comparison, (1) can fail for a number of reasons, only one of which is the fault of the `atom_encodings()` function (because it relies on featurization and message passing working correctly). The fix for this is to "mock" the output of `message_passing` so it's correct by construction. I'm ambivalent between the two approaches: mocking (1) and refactoring (2).
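To make (2) concrete, here is a sketch of what the utility and its unit test might look like (the name split_into_tensors and its signature are hypothetical, not code from this PR):

import torch
from torch import Tensor

def split_into_tensors(H: Tensor, batch: Tensor, n: int) -> tuple[Tensor, ...]:
    """Split an (N, d) tensor into n per-molecule chunks using a batch index of length N."""
    sizes = torch.bincount(batch, minlength=n).tolist()
    return H.split(sizes)

def test_split_into_tensors():
    H = torch.randn(5, 4)                  # 5 "atoms" with hidden size 4
    batch = torch.tensor([0, 0, 1, 1, 1])  # molecule index of each atom
    Hs = split_into_tensors(H, batch, n=2)

    assert len(Hs) == 2
    assert Hs[0].shape == (2, 4) and Hs[1].shape == (3, 4)
    assert torch.equal(torch.cat(Hs), H)   # the split is lossless

With this factoring, atom_encodings() reduces to a single call to the utility, and only split_into_tensors() needs a direct unit test.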

@kevingreenman kevingreenman added this to the v2.0.1 milestone May 18, 2024