[Feat] Add JSSP environment #177
Conversation
from .generator import JSSPFileGenerator, JSSPGenerator

class JSSPEnv(FJSPEnv):
Awesome! I love that the code is being reused in such a smart way.
Yesss, seeing this inheritance made me go: wow, nice.
rl4co/models/__init__.py
Outdated
@@ -19,7 +19,7 @@
from rl4co.models.rl.ppo.ppo import PPO
from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, get_reinforce_baseline
from rl4co.models.rl.reinforce.reinforce import REINFORCE
- from rl4co.models.zoo import HetGNNModel
+ from rl4co.models.zoo import L2DModel
Good, let's make sure the baselines have their names! L2D is a very influential paper in the NCO community.
@@ -73,7 +73,8 @@ def gather_by_index(src, idx, dim=1, squeeze=True):
expanded_shape = list(src.shape)
expanded_shape[dim] = -1
idx = idx.view(idx.shape + (1,) * (src.dim() - idx.dim())).expand(expanded_shape)
- return src.gather(dim, idx).squeeze() if squeeze else src.gather(dim, idx)
+ squeeze = idx.size(dim) == 1 and squeeze
+ return src.gather(dim, idx).squeeze(dim) if squeeze else src.gather(dim, idx)

def unbatchify_and_gather(x: Tensor, idx: Tensor, n: int):
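The guard added in this diff prevents `.squeeze()` from accidentally dropping a batch dimension of size 1. A minimal sketch of the difference (function reproduced from the diff; the demo tensors are made up for illustration):

```python
import torch

def gather_by_index(src, idx, dim=1, squeeze=True):
    expanded_shape = list(src.shape)
    expanded_shape[dim] = -1
    idx = idx.view(idx.shape + (1,) * (src.dim() - idx.dim())).expand(expanded_shape)
    squeeze = idx.size(dim) == 1 and squeeze  # only squeeze when the gathered dim is 1
    return src.gather(dim, idx).squeeze(dim) if squeeze else src.gather(dim, idx)

src = torch.arange(12.0).view(1, 3, 4)  # batch of 1, 3 nodes, 4 features
idx = torch.tensor([[0, 2]])            # gather two nodes per batch element
out = gather_by_index(src, idx)
# the old version called .squeeze() without a dim argument, which would
# also have dropped the leading batch dimension of size 1 -> shape (2, 4)
assert out.shape == (1, 2, 4)
```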
[Minor] It might be faster if you add a @torch.jit.script decorator. Not 100% sure though.
@Junyoungpark tagging you since you know this problem very well! Do you think we have a chance to include ScheduleNet?
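For reference, the suggestion above refers to compiling the function with TorchScript. A hedged sketch (the function below is illustrative, and whether scripting actually helps here would need benchmarking):

```python
import torch

@torch.jit.script
def gather_nodes(src: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # TorchScript compiles this once, cutting Python dispatch overhead on
    # repeated calls; any speedup depends on tensor sizes and call frequency
    return src.gather(1, idx.unsqueeze(-1).expand(-1, -1, src.size(-1)))

src = torch.arange(24.0).view(2, 3, 4)
idx = torch.tensor([[0], [2]])
out = gather_nodes(src, idx)
assert out.shape == (2, 1, 4)
```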
Great job!
# update adjacency matrices (remove edges)
td["proc_times"] = td["proc_times"].scatter(
    2,
    selected_op[:, None, None].expand(-1, self.num_mas, 1),
[Minor, Enhancement] Using einops.repeat could be "slightly" more efficient 😁: repeat(selected_op, 'b -> b n d', n=self.num_mas, d=1)
Is it though? From what I know, einops is slightly slower if you already know the dimensions, but take this with a grain of salt. I agree that it's more readable, though.
Okay, I did a trial and einops.repeat is actually way slower than tensor.expand 😂 (around 4x slower at large scale). So I think it's good to keep using tensor.expand.
yeah I think they use very clever optimizations in torch.expand() 😄
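The gap is plausible because `Tensor.expand` returns a zero-copy view (broadcasting via zero strides), while `Tensor.repeat` materializes a new tensor. A quick sketch of the difference:

```python
import torch

x = torch.arange(4)

e = x[:, None, None].expand(-1, 3, 1)  # view: no new memory allocated
r = x[:, None, None].repeat(1, 3, 1)   # copy: new storage allocated

assert e.shape == (4, 3, 1) and r.shape == (4, 3, 1)
assert e.data_ptr() == x.data_ptr()    # expand shares the original storage
assert r.data_ptr() != x.data_ptr()    # repeat does not
assert e.stride()[1] == 0              # expanded dim is broadcast, stride 0
```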
[Minor] Maybe we don't need this file, for a cleaner file structure.
# self-loop is added by GCNConv layer
return get_full_graph_edge_index(td.device, num_nodes, self_loop=False)

class GCNEncoder(nn.Module):
I like this clean refactoring, the logic is clearer. But will get_full_graph_edge_index() be called at every forward step? I.e., in the previous version, if it's a fully connected graph, the edge_index is saved as a class variable instead of being regenerated every time.
That's true, but the result is cached, so it should not be too slow. In fact, this implementation should be much faster than before (at least it was in my experiments), because it avoids the list comprehension over the batch data within the forward pass. But I agree, it's still not optimal; I will revisit this in the near future.
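A sketch of the caching idea discussed here (a hypothetical simplified helper; the real get_full_graph_edge_index also takes a device argument):

```python
from functools import lru_cache

import torch

@lru_cache(maxsize=None)
def full_graph_edge_index(num_nodes: int, self_loop: bool = False) -> torch.Tensor:
    # computed once per (num_nodes, self_loop) pair, then served from cache
    adj = torch.ones(num_nodes, num_nodes, dtype=torch.bool)
    if not self_loop:
        adj.fill_diagonal_(False)   # drop i == i edges
    return adj.nonzero().t()        # shape (2, num_edges)

ei = full_graph_edge_index(3)
assert ei.shape == (2, 6)            # 3*3 minus 3 self-loops
assert full_graph_edge_index(3) is ei  # second call hits the cache
```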
Thanks for reviewing, guys. I will add a ton of changes here in a couple of minutes and I hope this PR will not get too messy. Let me know if we should go through it together.
Wow, lots of changes here! 😁 Really curious about the episodic / stepwise RL performances. Btw, feel free to merge anytime.
rl4co/envs/routing/tsp/env.py
Outdated
# NOTE Experimental TSP class for stepwise PPO

class TSPEnv4PPO(TSPEnv):
[Minor] This may be called DenseRewardTSPEnv or similar?
Sure! Btw., stepwise PPO for TSP indeed converges to the nearest neighbor heuristic, at least with the stepwise reward as it is defined here (the distance added by the action).
Do you have a preference as to what to call the stepwise PPO in the paper (dense, stepwise, something else)? And then we should probably adjust the description of PPO in the appendix.
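For concreteness, a sketch of such a dense per-step reward (a hypothetical helper, not the exact PR code): the reward of selecting the next city is the negative distance it adds to the tour.

```python
import torch

def dense_step_reward(locs: torch.Tensor, prev: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
    """Negative Euclidean distance from the previously visited city to the chosen one.

    locs: (batch, num_cities, 2) coordinates
    prev, action: (batch,) city indices
    """
    batch = torch.arange(locs.size(0))
    return -(locs[batch, action] - locs[batch, prev]).norm(dim=-1)

locs = torch.tensor([[[0.0, 0.0], [3.0, 4.0]]])
r = dense_step_reward(locs, prev=torch.tensor([0]), action=torch.tensor([1]))
assert torch.allclose(r, torch.tensor([-5.0]))  # 3-4-5 triangle
```

Greedily maximizing this per-step reward is exactly the nearest-neighbor heuristic, which is consistent with the convergence behavior described above.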
@@ -58,7 +57,7 @@ def __init__(
generator_params: dict = {},
**kwargs,
):
- super().__init__(**kwargs)
+ super().__init__(check_solution=False, **kwargs)
Is this always the case (no solution check)?
Uhm, I think this is there just because there is no check implemented for FFSP yet, haha. Let me see if I can get one implemented.
No worries, this is not a pressing issue! Actually, it's even better to keep it False during training, for efficiency.
Great job!! This PR is truly huge.
Description
Motivation and Context
JSSP is a common CO problem, widely used to benchmark new algorithms.
I have raised an issue to propose this change (required for new features and bug fixes)
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!