
The token list is inconsistent with the tokens provided in LogP_optimization_demo.ipynb — can transfer learning still be used? #30

Open
zhouhao-learning opened this issue May 31, 2019 · 4 comments

Comments

@zhouhao-learning

Hello, when I train a generative model on my own SMILES data using LogP_optimization_demo.ipynb, the notebook provides this token list:
tokens = ['<', '>', '#', '%', ')', '(', '+', '-', '/', '.', '1', '0', '3', '2', '5', '4', '7', '6', '9', '8', '=', 'A', '@', 'C', 'B', 'F', 'I', 'H', 'O', 'N', 'P', 'S', '[', ']', '\\', 'c', 'e', 'i', 'l', 'o', 'n', 'p', 's', 'r', '\n']
but my data contains characters outside this list, which prevents me from continuing with the transfer learning approach. So I changed the code as follows during training:
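A quick way to confirm which characters fall outside the notebook's token list is a set difference over the dataset. This is a minimal sketch (the `smiles_data` entries here are illustrative, not from the user's actual file):

```python
# Find characters in a SMILES dataset that are missing from a fixed token
# list. Any character reported here would break transfer learning against
# a checkpoint trained on this vocabulary.
tokens = ['<', '>', '#', '%', ')', '(', '+', '-', '/', '.', '1', '0',
          '3', '2', '5', '4', '7', '6', '9', '8', '=', 'A', '@', 'C',
          'B', 'F', 'I', 'H', 'O', 'N', 'P', 'S', '[', ']', '\\', 'c',
          'e', 'i', 'l', 'o', 'n', 'p', 's', 'r', '\n']

smiles_data = ['CC(=O)Oc1ccccc1C(=O)O', '[Na+].[Cl-]']  # example entries

seen = set().union(*(set(s) for s in smiles_data))
missing = sorted(seen - set(tokens))
print(missing)  # e.g. ['a'] when salts like [Na+] appear in the data
```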

import torch

# Imports assume the ReLeaSE repository layout shown in the traceback below
from release.data import GeneratorData
from release.stackRNN import StackAugmentedRNN

use_cuda = torch.cuda.is_available()

gen_data_path = "data/nueji_data2.csv"
gen_data = GeneratorData(training_data_path=gen_data_path, delimiter='\t', 
                         cols_to_read=[0], keep_header=True, tokens=None)
hidden_size = 1500
stack_width = 1500
stack_depth = 200
layer_type = 'GRU'
lr = 0.001
optimizer_instance = torch.optim.Adadelta

my_generator = StackAugmentedRNN(input_size=gen_data.n_characters, hidden_size=hidden_size,
                                 output_size=gen_data.n_characters, layer_type=layer_type,
                                 n_layers=1, is_bidirectional=False, has_stack=True,
                                 stack_width=stack_width, stack_depth=stack_depth, 
                                 use_cuda=use_cuda, 
                                 optimizer_instance=optimizer_instance, lr=lr)
model_path = './checkpoints/generator/checkpoint_biggest_rnn'
my_generator.load_model(model_path)

But I get the following error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-11-3c9498b26c8c> in <module>()
----> 1 my_generator.load_model(model_path)

/scratch2/hzhou/Drug/generate_smiles/ReLeaSE/release/stackRNN.py in load_model(self, path)
    140         """
    141         weights = torch.load(path)
--> 142         self.load_state_dict(weights)
    143 
    144     def save_model(self, path):

~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    717         if len(error_msgs) > 0:
    718             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 719                                self.__class__.__name__, "\n\t".join(error_msgs)))
    720 
    721     def parameters(self):

RuntimeError: Error(s) in loading state_dict for StackAugmentedRNN:
	size mismatch for encoder.weight: copying a param of torch.Size([40, 1500]) from checkpoint, where the shape is torch.Size([45, 1500]) in current model.
	size mismatch for decoder.weight: copying a param of torch.Size([40, 1500]) from checkpoint, where the shape is torch.Size([45, 1500]) in current model.
	size mismatch for decoder.bias: copying a param of torch.Size([40]) from checkpoint, where the shape is torch.Size([45]) in current model.
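The mismatch above means the checkpoint was trained with a 40-token vocabulary, while the current model (built from the new data) expects 45 tokens, so the encoder/decoder matrices have different shapes. One common workaround — a hedged sketch here, not part of the ReLeaSE API — is to grow the pretrained weight matrices to the new vocabulary, copying rows for shared tokens and initializing the new rows randomly, before calling `load_state_dict`:

```python
import numpy as np

# Sketch: expand a pretrained embedding matrix from an old vocabulary to a
# new superset vocabulary. Rows for tokens present in both vocabularies are
# copied from the checkpoint; rows for new tokens are randomly initialized.
# The same surgery would be applied to each mismatched tensor in state_dict.
def grow_embedding(old_weights, old_tokens, new_tokens, rng=None):
    rng = rng or np.random.default_rng(0)
    hidden = old_weights.shape[1]
    new_weights = rng.normal(scale=0.02, size=(len(new_tokens), hidden))
    old_index = {t: i for i, t in enumerate(old_tokens)}
    for j, t in enumerate(new_tokens):
        if t in old_index:  # token was seen during pretraining
            new_weights[j] = old_weights[old_index[t]]
    return new_weights

old_tokens = ['C', 'N', 'O', '=']
new_tokens = ['C', 'N', 'O', '=', 'a']  # 'a' is new, e.g. from [Na+]
old_w = np.arange(8, dtype=float).reshape(4, 2)
new_w = grow_embedding(old_w, old_tokens, new_tokens)
print(new_w.shape)  # (5, 2); the first four rows match old_w
```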

But my data set is very small; without transfer learning, my generative model may not be able to learn the chemical rules of SMILES. So my idea is this: retrain a model on the `data/chembl_22_clean_1576904_sorted_std_final.smi` data set, but with a customized token list that also covers the characters in my own data set, and then fine-tune that pre-trained model on my data. Is this idea right? I'm not sure.
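The idea described above — one token list covering both corpora — can be sketched as a simple character union. The in-memory lists here stand in for the real `.smi`/`.csv` files; the combined list would then be passed to both `GeneratorData` instances via the `tokens` argument:

```python
# Build one token list that covers the pretraining corpus (e.g. ChEMBL)
# and the user's own SMILES, so fine-tuning never sees an unknown token.
chembl_smiles = ['CC(=O)Oc1ccccc1C(=O)O', 'c1ccncc1']  # illustrative
my_smiles = ['[Na+].[Cl-]', 'CCO']                     # illustrative

chars = set()
for s in chembl_smiles + my_smiles:
    chars.update(s)
tokens = sorted(chars) + ['\n']  # newline terminator, as in the demo list
print(len(tokens), tokens)
```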

@isayev
Owner

isayev commented Jun 1, 2019

What kinds of extra characters do you have? You probably need to standardize your SMILES (remove metals, mixtures, stereochemistry, etc.).

@zhouhao-learning
Author

@isayev
My SMILES contain the extra character 'a' because some entries contain Na and Ca. What do you mean by standardized SMILES? What do I need to do? Thank you.

@quangnguyenbn99

Hi @zhouhao-learning,
Did you solve your problem? I am facing the same issue. If you have the solution, please enlighten me.

@gmseabra

> Hi @zhouhao-learning,
> Did you solve your problem? I am facing the same issue. If you have the solution, please enlighten me.

Although the question is old, I'm answering it now because it seems it is still unresolved...

Basically, the point is that you generally don't want ions (Na+, Ca2+) in your compound library, since they are just counterions to your compound. So you need to remove them from your SMILES data before using it.

Take a look at: https://molvs.readthedocs.io/en/latest/
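As a very rough illustration of the advice above: a SMILES string with counterions contains '.'-separated fragments, and a crude heuristic is to keep only the largest fragment. This is a stdlib-only sketch — for real work use MolVS (linked above) or RDKit's salt-removal tools, which understand chemistry rather than string length:

```python
# Crude counterion stripping: split a SMILES on '.' (fragment separator)
# and keep the longest fragment. Only a heuristic; a standardization
# library should be used for production data.
def strip_counterions(smiles):
    fragments = smiles.split('.')
    return max(fragments, key=len)

print(strip_counterions('[Na+].CC(=O)[O-]'))  # CC(=O)[O-]
print(strip_counterions('CCO'))               # CCO (unchanged)
```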

Best.
