[Feat] Adding support for improvement method #174

Merged · 5 commits · May 28, 2024
Conversation


@yining043 yining043 commented May 12, 2024

Description

This PR makes RL4CO support improvement methods for VRPs. The changes include:

  • Add new features for improvement methods:
    • improvement env classes
    • improvement base model, encoder, and decoder classes
    • positional encoding/embedding support
  • Minor changes to augment existing features:
    • nn.attention, nn.ops
    • decoding strategy
    • a new type of normalization method
  • Add the N2S implementation
  • Modify the PPO training part for improvement methods

Motivation and Context

This PR makes RL4CO support improvement methods, which iteratively refine an existing solution rather than constructing one from scratch.

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly (to do).

@yining043 yining043 changed the title Adding support for improvement method Adding support for improvement method (draft version) May 12, 2024
@fedebotu fedebotu requested review from fedebotu, cbhua and LTluttmann and removed request for fedebotu May 12, 2024 12:38

@fedebotu fedebotu left a comment


Great job 🚀

Left some comments~

rl4co/envs/routing/pdp/env.py (resolved)
rl4co/envs/routing/pdp/env.py (resolved)
rl4co/envs/routing/pdp/env.py (resolved)
rl4co/envs/routing/pdp/generator.py (resolved)
rl4co/models/common/improvement/base.py
    def forward(self, x):
        if isinstance(self.normalizer, nn.BatchNorm1d):
            return self.normalizer(x.view(-1, x.size(-1))).view(*x.size())
        elif isinstance(self.normalizer, nn.InstanceNorm1d):
            return self.normalizer(x.permute(0, 2, 1)).permute(0, 2, 1)
        elif self.normalizer == 'layer':
Member


Can't we initialize the LayerNorm from PyTorch instead?

Also @cbhua @LTluttmann, it would be a good idea to allow for different normalizations: if the user passes a str, we should be able to recover the corresponding type from PyTorch.

Contributor Author


This is tricky! The idea is exactly LayerNorm, but PyTorch's LayerNorm requires pre-defining the shape of the tensors to normalize, so in our case we would need to know the graph_size in advance (the mean and variance are computed w.r.t. both graph_size and embed_dim). So I wrote my own normalization. Also, the API for LayerNorm differs from the other normalization layers, e.g., it uses arguments like elementwise_affine rather than affine.
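A minimal sketch of the idea being described (hypothetical, not the actual RL4CO code): normalizing over both the graph_size and embed_dim dimensions without fixing any shape at construction time, which nn.LayerNorm would require:

```python
import torch

def graph_layer_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # x: [batch, graph_size, embed_dim]. Mean and variance are computed
    # over BOTH the graph_size and embed_dim dimensions, so no normalized
    # shape has to be declared up front (unlike nn.LayerNorm, whose
    # normalized_shape is fixed in __init__).
    mean = x.mean(dim=(1, 2), keepdim=True)
    var = x.var(dim=(1, 2), unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)
```

An affine transform (learnable scale and shift) could be layered on top, mirroring what affine/elementwise_affine do in the built-in normalization modules.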

rl4co/models/nn/pos_embeddings.py (resolved)
rl4co/models/zoo/n2s/encoder.py (resolved)
        self.actions.append(selected_action)
        self.logprobs.append(logprobs)
        return td
    # skip this step for improvement methods, since the action for improvement methods is finalized in its own policy
Member


[Internal comment] @cbhua @LTluttmann
this may be needed in other occasions as well. We may want to standardize the API

Member


Agree. Maybe we can rename step() -> _step() and add a step() wrapper outside: _step() calculates the logprobs and selected_actions, and step() selects the variables to return.

Contributor Author


Yeah, I agree with what @cbhua suggested. We can do it! For now, I have not performed any refactoring here.

Contributor


Another option might be to define a non-internal function like select_action() which only returns logprobs and selected_actions, while the step() method additionally adds the action to the internal list and the tensordict. That might be more readable than having yet another wrapper with a step() func ^.^
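As a rough illustration of the proposed split (all names and signatures here are hypothetical; RL4CO's actual decoding strategies operate on tensors and TensorDicts), select_action() stays side-effect free while step() records state:

```python
import math

class DecodingStrategy:
    """Illustrative sketch of the select_action()/step() split.

    Improvement methods that finalize actions in their own policy
    could call select_action() directly and skip the bookkeeping
    that step() performs.
    """

    def __init__(self):
        self.actions = []
        self.logprobs = []

    def select_action(self, scores):
        # Greedy selection over raw scores; returns the action index
        # and its softmax log-probability without mutating any state.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        action = scores.index(m)
        logprob = math.log(exps[action] / sum(exps))
        return action, logprob

    def step(self, scores, td):
        # step() wraps select_action() and additionally records the
        # action and logprob internally and in the tensordict.
        action, logprob = self.select_action(scores)
        self.actions.append(action)
        self.logprobs.append(logprob)
        td["action"] = action
        return td
```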


@cbhua cbhua left a comment


Awesome work! 🚀 Played with the code and it works pretty efficiently! Love it.

Oops, I should have started a review first and then commented. I made quite a few separate comments; I hope it's not too massive. 🤪

@yining043
Contributor Author

Hi @cbhua @fedebotu, thank you so much for the review and the great suggestions! I have replied above. Since I forgot to run pre-commit last time, I did a forced re-commit of the files; sorry if that makes the changes a bit hard to track. I will make new commits for future updates! :)


@LTluttmann LTluttmann left a comment


Awesome work and a great addition to rl4co!

rl4co/models/nn/pos_embeddings.py (resolved)
rl4co/models/nn/env_embeddings/init.py (resolved)
rl4co/models/nn/improvement_attention.py (resolved)
rl4co/models/nn/improvement_attention.py (resolved)

yining043 commented May 13, 2024

Hi @LTluttmann, thanks for the review! I have changed the code in the latest commit! I marked this pull request as a draft since I need to add more features before merging~

@yining043 yining043 marked this pull request as draft May 13, 2024 14:53
@yining043 yining043 marked this pull request as ready for review May 27, 2024 12:07
@fedebotu fedebotu self-requested a review May 27, 2024 14:31

@fedebotu fedebotu left a comment


Great job!

I went through the code and I don't have particular comments, since you mentioned this is working well, except that in the future we should document the new classes a little more (but not now).

Two comments:

  1. Slightly more pressing: could you add a simple test like this so that RL4CO can automatically detect if there are any problems?
  2. I noticed that there are some conflicts that prevent merging; this can be solved like this. Let me know if you need help!


class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_neurons: List[int] = [64, 32],
        dropout_probs: Union[None, List[float]] = None,
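For illustration, the truncated signature above could be completed along these lines; this is a sketch of the common MLP-with-dropout pattern, not RL4CO's actual implementation:

```python
from typing import List, Union

import torch
import torch.nn as nn

class MLP(nn.Module):
    # Illustrative completion: hidden layers of sizes num_neurons,
    # each followed by ReLU and an optional dropout, then a final
    # linear projection to output_dim.
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_neurons: List[int] = [64, 32],
        dropout_probs: Union[None, List[float]] = None,
    ):
        super().__init__()
        if dropout_probs is None:
            dropout_probs = [0.0] * len(num_neurons)
        layers = []
        dims = [input_dim] + num_neurons
        for i in range(len(num_neurons)):
            layers += [
                nn.Linear(dims[i], dims[i + 1]),
                nn.ReLU(),
                nn.Dropout(dropout_probs[i]),
            ]
        layers.append(nn.Linear(dims[-1], output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```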
Member


[Minor] Note that we could also use the MLP from TorchRL, but there's no need to change it now, since here we can add more custom features.

Contributor Author


Hi @fedebotu, thanks for the comments! I have added the tests for N2S and resolved the conflicts!


@cbhua cbhua left a comment


Great work! 🚀 Left some random comments; really minor if the model is working properly 😂

rl4co/envs/routing/pdp/env.py (resolved)
rl4co/envs/routing/pdp/env.py
@yining043 yining043 changed the title Adding support for improvement method (draft version) [Feat] Adding support for improvement method May 28, 2024
@fedebotu fedebotu self-requested a review May 28, 2024 10:01

@fedebotu fedebotu left a comment


🚀


@cbhua cbhua left a comment


🚀

@cbhua cbhua merged commit 3542f8e into ai4co:main May 28, 2024
12 checks passed