Atomic encodings #861
base: main
@@ -129,6 +129,10 @@ def encoding(
        """Calculate the :attr:`i`-th hidden representation"""
        return self.predictor.encode(self.fingerprint(bmg, V_d, X_d), i)

    def atomic_encodings(self, bmg: BatchMolGraph, V_d: Tensor | None = None) -> tuple[Tensor]:
        H_v = self.message_passing(bmg, V_d)
        return H_v.split(torch.bincount(bmg.batch).tolist())
-        return H_v.split(torch.bincount(bmg.batch).tolist())
+        sizes = torch.bincount(bmg.batch, minlength=len(bmg)).tolist()
+        return H_v.split(sizes)
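The reason for `minlength` can be sketched without PyTorch. The helpers below (`count_per_graph`, `split_rows`) are hypothetical stand-ins for `torch.bincount` and `Tensor.split`, and the sketch assumes `len(bmg)` is the number of molecules in the batch. If the last molecule contributes no atoms (assuming that can happen at all), counting without a minimum length yields one chunk too few:

```python
def count_per_graph(batch, minlength=0):
    """Count atoms per graph index (pure-Python stand-in for torch.bincount)."""
    n = max(max(batch, default=-1) + 1, minlength)
    counts = [0] * n
    for g in batch:
        counts[g] += 1
    return counts


def split_rows(rows, sizes):
    """Cut rows into consecutive chunks of the given sizes (stand-in for Tensor.split)."""
    chunks, start = [], 0
    for size in sizes:
        chunks.append(rows[start:start + size])
        start += size
    return tuple(chunks)


# Hypothetical batch of 3 molecules in which the last one has no atoms.
batch = [0, 0, 1, 1, 1]                         # graph index of each atom
H_v = [[float(i)] for i in range(len(batch))]   # one made-up hidden vector per atom
n_graphs = 3

without_min = split_rows(H_v, count_per_graph(batch))
with_min = split_rows(H_v, count_per_graph(batch, minlength=n_graphs))

print(len(without_min), len(with_min))
```

With the minimum length set, the result has one entry per molecule, so it can be zipped against the input molecules without an off-by-one.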
        self, bmgs: Iterable[BatchMolGraph], V_ds: Iterable[Tensor | None]
    ) -> list[tuple[Tensor]]:
        H_vs: list[Tensor] = self.message_passing(bmgs, V_ds)
        return [H_v.split(torch.bincount(bmgs.batch).tolist()) for H_v in H_vs]
same as above
FWIW I prefer
The advantage of (2) is that it's simpler and more isolated. The unit test only needs to cover the
Description
#722 asked about getting an interface for atomic encodings. This is simple to implement and may not be entirely necessary because users can do it themselves as @davidegraff showed in a comment. But in v2.2 we plan to add support for atom and bond targets which would probably benefit from a function at the model level to get the atom encodings. We could add it now, given that it would be simple.
Example
Each element of atom_encodings holds the atomic fingerprints for a single molecule.
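A minimal sketch of what that return value looks like, using stand-in tensors rather than a real model. `H_v` and `batch` below are made-up placeholders for the message-passing output and `bmg.batch`, and the hidden size of 4 is arbitrary:

```python
import torch

# Placeholder per-atom hidden states for a batch of two molecules
# (2 atoms, then 3 atoms); in the PR this comes from message passing.
H_v = torch.arange(20, dtype=torch.float32).reshape(5, 4)
# Placeholder for bmg.batch: the molecule index of each atom.
batch = torch.tensor([0, 0, 1, 1, 1])

# Count atoms per molecule, then cut H_v into one block per molecule.
sizes = torch.bincount(batch, minlength=2).tolist()
atom_encodings = H_v.split(sizes)

shapes = [tuple(h.shape) for h in atom_encodings]
print(shapes)
```

Each entry is an (atoms in that molecule) x (hidden size) tensor, in the same order as the molecules in the batch.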
Questions
What is a good name for this function? I brainstormed `atom_encodings`, `atomic_encoding`, and `atoms_fingerprint`.
I still have a couple of things to do before merging this if the reviews are positive.
Checklist