how to use pre-trained model correctly? #163

Open
karin0018 opened this issue Dec 9, 2021 · 2 comments

@karin0018

karin0018 commented Dec 9, 2021

Hi, I am trying to use a pre-trained model on the ESOL dataset:

```python
from tqdm import tqdm
import dgl
from dgllife.data import ESOL
from dgllife.model import load_pretrained
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, AttentiveFPAtomFeaturizer, CanonicalBondFeaturizer, AttentiveFPBondFeaturizer

dataset_canonical = ESOL(smiles_to_bigraph, CanonicalAtomFeaturizer(),CanonicalBondFeaturizer())

model = load_pretrained('Weave_canonical_ESOL') # Pretrained model loaded
model.eval()

for smiles, g, label in tqdm(dataset_canonical):
    nfeats = g.ndata['h']
    efeats = g.edata['e']
    label_pred = model(g, nfeats, efeats)
    print(label_pred)
    print(label)
```

This throws the following error:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_184688/2242391364.py in <module>
      7     nfeats = g.ndata['h']
      8     efeats = g.edata['e']
----> 9     label_pred = model(g, nfeats, efeats)
     10     print(label_pred)
     11     print(label)

~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/envs/dgl/lib/python3.9/site-packages/dgllife/model/model_zoo/weave_predictor.py in forward(self, g, node_feats, edge_feats)
    103             Prediction for the graphs in the batch. G for the number of graphs.
    104         """
--> 105         node_feats = self.gnn(g, node_feats, edge_feats, node_only=True)
    106         node_feats = self.node_to_graph(node_feats)
    107         g_feats = self.readout(g, node_feats)

~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/envs/dgl/lib/python3.9/site-packages/dgllife/model/gnn/weave.py in forward(self, g, node_feats, edge_feats, node_only)
    208         """
    209         for i in range(len(self.gnn_layers) - 1):
--> 210             node_feats, edge_feats = self.gnn_layers[i](g, node_feats, edge_feats)
    211         return self.gnn_layers[-1](g, node_feats, edge_feats, node_only)

~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/envs/dgl/lib/python3.9/site-packages/dgllife/model/gnn/weave.py in forward(self, g, node_feats, edge_feats, node_only)
    107         # Update node features
    108         node_node_feats = self.activation(self.node_to_node(node_feats))
--> 109         g.edata['e2n'] = self.activation(self.edge_to_node(edge_feats))
    110         g.update_all(fn.copy_edge('e2n', 'm'), fn.sum('m', 'e2n'))
    111         edge_node_feats = g.ndata.pop('e2n')

~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/linear.py in forward(self, input)
    101 
    102     def forward(self, input: Tensor) -> Tensor:
--> 103         return F.linear(input, self.weight, self.bias)
    104 
    105     def extra_repr(self) -> str:

~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1846     if has_torch_function_variadic(input, weight, bias):
   1847         return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848     return torch._C._nn.linear(input, weight, bias)
   1849 
   1850 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (68x12 and 13x256)
```
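
The error message itself encodes the mismatch. `mat1` is the 68 × 12 edge-feature matrix (68 edges, 12 features each). `mat2` is the transposed weight of the `edge_to_node` layer, which the model printout in this issue lists as `Linear(in_features=13, out_features=256)`. The sketch below is a simplified pure-Python model of that shape rule, not PyTorch's actual implementation. It only shows why the inner dimensions (12 vs 13) must agree:

```python
def linear_output_shape(input_shape, in_features, out_features):
    """Simplified shape rule for y = x @ W.T + b, with W of shape (out_features, in_features).

    mat1 is the input x; mat2 is W.T, shaped (in_features, out_features).
    The inner dimensions must agree, otherwise the multiplication is undefined.
    """
    rows, cols = input_shape
    if cols != in_features:
        raise RuntimeError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{cols} and {in_features}x{out_features})"
        )
    return (rows, out_features)


# 68 edges with 12 features each, fed to edge_to_node = Linear(13, 256)
try:
    linear_output_shape((68, 12), in_features=13, out_features=256)
except RuntimeError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (68x12 and 13x256)

# A hypothetical edge count with 13 features per edge passes the check
print(linear_output_shape((80, 13), in_features=13, out_features=256))  # (80, 256)
```

The row count (number of edges) never matters here. Only the feature width has to match what the pre-trained layer was built with.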

I checked the shape of the edge features and the construction of `WeavePredictor`, and found that they do not match:

```
>>> smiles, g, label = dataset_canonical[0]
>>> print(g.edata['e'].shape)
torch.Size([68, 12])
>>> print(model)
WeavePredictor(
  (gnn): WeaveGNN(
    (gnn_layers): ModuleList(
      (0): WeaveLayer(
        (node_to_node): Linear(in_features=74, out_features=256, bias=True)
        (edge_to_node): Linear(in_features=13, out_features=256, bias=True)
        (update_node): Linear(in_features=512, out_features=256, bias=True)
        (left_node_to_edge): Linear(in_features=74, out_features=256, bias=True)
        (right_node_to_edge): Linear(in_features=74, out_features=256, bias=True)
        (edge_to_edge): Linear(in_features=13, out_features=256, bias=True)
        (update_edge): Linear(in_features=768, out_features=256, bias=True)
      )
      ...
```
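
The node side looks consistent: `node_to_node` expects 74 features, and the traceback only fails at `edge_to_node`. The edge side is off by exactly one. One plausible explanation, stated here as an assumption, is that the pre-trained model was trained with an extra column flagging self-loop edges. In dgllife, `CanonicalBondFeaturizer` accepts a `self_loop` argument and `smiles_to_bigraph` accepts an `add_self_loop` argument. The sketch below reconstructs the edge-feature width by hand. It assumes dgllife's default canonical bond layout: bond-type one-hot, conjugation flag, ring flag, and stereo one-hot. The lists are plain strings for illustration, not the library's RDKit enums.

```python
# Assumed layout of dgllife's default CanonicalBondFeaturizer (illustrative strings,
# not RDKit enums); the widths are counted by hand, not read from the library.
BOND_TYPES = ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"]
BOND_STEREO = ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE", "STEREOCIS", "STEREOTRANS"]


def canonical_bond_feat_size(self_loop: bool = False) -> int:
    size = len(BOND_TYPES)    # one-hot bond type
    size += 1                 # is the bond conjugated
    size += 1                 # is the bond in a ring
    size += len(BOND_STEREO)  # one-hot stereo configuration
    if self_loop:
        size += 1             # extra column marking self-loop edges
    return size


print(canonical_bond_feat_size())                # 12, the width the dataset produced
print(canonical_bond_feat_size(self_loop=True))  # 13, the width edge_to_node expects
```

If that assumption holds, you could build the dataset along the lines of `ESOL(partial(smiles_to_bigraph, add_self_loop=True), CanonicalAtomFeaturizer(), CanonicalBondFeaturizer(self_loop=True))`, using `functools.partial`. That should yield 13-dim edge features. Adding self-loops also adds one edge per atom, so the edge count will be larger than 68.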

How can I solve this error? Thanks a lot for your help!

@mufeili
Contributor

mufeili commented Dec 9, 2021

Can you take a look at #162? I believe it's the same issue.

@karin0018
Author

Ok, that works! Thanks a lot!
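
A mismatch like this surfaces deep inside the forward pass as an opaque matmul error. A small pre-flight check can catch it earlier with a clearer message. The helper below is a hypothetical sketch (`check_feature_widths` is not part of dgllife). It compares the last dimension of the node and edge feature shapes against the widths the model expects:

```python
from typing import Sequence


def check_feature_widths(node_shape: Sequence[int], edge_shape: Sequence[int],
                         expected_node: int, expected_edge: int) -> None:
    """Raise a readable ValueError before the forward pass if feature widths disagree."""
    problems = []
    if node_shape[-1] != expected_node:
        problems.append(f"node features are {node_shape[-1]}-dim, model expects {expected_node}")
    if edge_shape[-1] != expected_edge:
        problems.append(f"edge features are {edge_shape[-1]}-dim, model expects {expected_edge}")
    if problems:
        raise ValueError("; ".join(problems))


# Widths from this issue; the atom count 30 is a hypothetical placeholder.
try:
    check_feature_widths((30, 74), (68, 12), expected_node=74, expected_edge=13)
except ValueError as err:
    print(err)  # edge features are 12-dim, model expects 13
```

With a real graph and model, you could pass `g.ndata['h'].shape` and `g.edata['e'].shape` (a `torch.Size` is a tuple, so `[-1]` works). You could read the expected widths from `model.gnn.gnn_layers[0].node_to_node.in_features` and `model.gnn.gnn_layers[0].edge_to_node.in_features`. Those attribute names come from the model printout above.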