
Convert TSP to ATSP #108

Closed · Mu-Yanchen opened this issue Jan 1, 2024 · 4 comments
Labels: bug (Something isn't working)

@Mu-Yanchen

Describe the bug

I used the notebook linked below to learn about rl4co: https://github.com/ai4co/rl4co/blob/main/notebooks/tutorials/2-creating-new-env-model.ipynb. I now want to try the ATSP problem, so I import ATSPEnv instead of TSPEnv like this:

from rl4co.envs import ATSPEnv
from rl4co.models.nn.utils import rollout, random_policy  # helpers used by the tutorial

batch_size = 2
env_atsp = ATSPEnv(num_loc=30)
reward, td, actions = rollout(env_atsp, env_atsp.reset(batch_size=[batch_size]), random_policy)
env_atsp.render(td, actions)

This runs correctly, but when I do a greedy rollout with the untrained model as below, I get the following error:
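
(For context, model_atsp below is presumably the untrained model created earlier in the notebook, with the ATSP environment swapped in; a minimal sketch, assuming the tutorial's AttentionModel construction:)

from rl4co.models.zoo.am import AttentionModel  # import path as used in the tutorial

# Same construction as for TSP, just with the ATSP environment
model_atsp = AttentionModel(env_atsp, baseline="rollout")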

Greedy rollouts over untrained model

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td_init_atsp = env_atsp.reset(batch_size=[3]).to(device)
model_atsp = model_atsp.to(device)
out_atsp = model_atsp(td_init_atsp.clone(), phase="test", decode_type="greedy", return_actions=True)
actions_untrained = out_atsp['actions'].cpu().detach()
rewards_untrained = out_atsp['reward'].cpu().detach()

for i in range(3):
    print(f"Problem {i+1} | Cost: {-rewards_untrained[i]:.3f}")
    env_atsp.render(td_init_atsp[i], actions_untrained[i])

The full traceback is:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[5], line 5
      3 td_init_atsp = env_atsp.reset(batch_size=[3]).to(device)
      4 model_atsp = model_atsp.to(device)
----> 5 out_atsp = model_atsp(td_init_atsp.clone(), phase="test", decode_type="greedy", return_actions=True)
      6 actions_untrained = out_atsp['actions'].cpu().detach()
      7 rewards_untrained = out_atsp['reward'].cpu().detach()

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/miniconda3/lib/python3.10/site-packages/rl4co/models/rl/common/base.py:246, in RL4COLitModule.forward(self, td, **kwargs)
    244     log.info("Using env from kwargs")
    245     env = kwargs.pop("env")
--> 246 return self.policy(td, env, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/miniconda3/lib/python3.10/site-packages/rl4co/models/zoo/common/autoregressive/policy.py:140, in AutoregressivePolicy.forward(self, td, env, phase, return_actions, return_entropy, return_init_embeds, **decoder_kwargs)
    125 """Forward pass of the policy.
    126 
    127 Args:
   (...)
    136     out: Dictionary containing the reward, log likelihood, and optionally the actions and entropy
    137 """
    139 # ENCODER: get embeddings from initial state
--> 140 embeddings, init_embeds = self.encoder(td)
    142 # Instantiate environment if needed
    143 if isinstance(env, str) or env is None:

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/miniconda3/lib/python3.10/site-packages/rl4co/models/zoo/common/autoregressive/encoder.py:74, in GraphAttentionEncoder.forward(self, td, mask)
     62 """Forward pass of the encoder.
     63 Transform the input TensorDict into a latent representation.
     64 
   (...)
     71     init_h: Initial embedding of the input
     72 """
     73 # Transfer to embedding space
---> 74 init_h = self.init_embedding(td)
     76 # Process embedding
     77 h = self.net(init_h, mask)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/miniconda3/lib/python3.10/site-packages/rl4co/models/nn/env_embeddings/init.py:49, in TSPInitEmbedding.forward(self, td)
     48 def forward(self, td):
---> 49     out = self.init_embed(td["locs"])
     50     return out

File ~/miniconda3/lib/python3.10/site-packages/tensordict/tensordict.py:3697, in TensorDictBase.__getitem__(self, index)
   3695     idx_unravel = _unravel_key_to_tuple(index)
   3696     if idx_unravel:
-> 3697         return self._get_tuple(idx_unravel, NO_DEFAULT)
   3698 if (istuple and not index) or (not istuple and index is Ellipsis):
   3699     # empty tuple returns self
   3700     return self

File ~/miniconda3/lib/python3.10/site-packages/tensordict/tensordict.py:4625, in TensorDict._get_tuple(self, key, default)
   4624 def _get_tuple(self, key, default):
-> 4625     first = self._get_str(key[0], default)
   4626     if len(key) == 1 or first is default:
   4627         return first

File ~/miniconda3/lib/python3.10/site-packages/tensordict/tensordict.py:4621, in TensorDict._get_str(self, key, default)
   4619 out = self._tensordict.get(first_key, None)
   4620 if out is None:
-> 4621     return self._default_get(first_key, default)
   4622 return out

File ~/miniconda3/lib/python3.10/site-packages/tensordict/tensordict.py:1455, in TensorDictBase._default_get(self, key, default)
   1452     return default
   1453 else:
   1454     # raise KeyError
-> 1455     raise KeyError(
   1456         TensorDictBase.KEY_ERROR.format(
   1457             key, self.__class__.__name__, sorted(self.keys())
   1458         )
   1459     )

KeyError: 'key "locs" not found in TensorDict with keys [\'action_mask\', \'cost_matrix\', \'current_node\', \'done\', \'first_node\', \'i\', \'terminated\']'

Reason and Possible fixes

I think the problem is a mismatch between the model and ATSPEnv, but I have not found a solution. Thank you for your time and attention.
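
A quick way to confirm the mismatch (a small debugging sketch, reusing env_atsp from above):

# Inspect the keys the ATSP environment actually provides
td = env_atsp.reset(batch_size=[2])
print(sorted(td.keys()))
# -> ['action_mask', 'cost_matrix', 'current_node', 'done', 'first_node', 'i', 'terminated']
# TSPInitEmbedding.forward indexes td["locs"], which is absent here, hence the KeyError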

@Mu-Yanchen Mu-Yanchen added the bug Something isn't working label Jan 1, 2024
@Mu-Yanchen (Author)

By the way, how should I train an ATSP model in the same way as a TSP model?

@Haimrich (Contributor)

I think the problem comes from

"atsp": TSPInitEmbedding,

Basically, by default the same InitEmbedding used for TSP is reused for the ATSP environment. The issue is that for TSP you can simply embed the coordinates of each node ('locs' in the TSP environment) and let the encoder infer the Euclidean distances, whereas for ATSP I think you can't: all you have is an asymmetric distance matrix ('cost_matrix' in the ATSP environment), and giving the encoder node coordinates would not help it understand why going from one node to another has one cost while going back has a different one.

So I think that, in order to solve the ATSP problem with the AM model, you need a custom InitEmbedding that encodes the nodes in a way that also provides information about the asymmetric distance matrix. Maybe a GNN or something like that; see the sketch below.
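
One minimal sketch of such an embedding (hypothetical, not part of rl4co: it simply describes each node by its row of the cost matrix, and it assumes a fixed num_loc):

import torch.nn as nn

class ATSPCostRowInitEmbedding(nn.Module):
    """Hypothetical init embedding for ATSP: represent node i by the vector of
    costs from i to every other node, projected into the embedding space.
    Assumes a fixed number of nodes, since the input width depends on it."""

    def __init__(self, num_loc: int, embed_dim: int = 128):
        super().__init__()
        self.project = nn.Linear(num_loc, embed_dim)

    def forward(self, td):
        # td["cost_matrix"]: [batch, num_loc, num_loc]; row i holds the outgoing costs of node i
        return self.project(td["cost_matrix"])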

@fedebotu fedebotu assigned Junyoungpark and unassigned fedebotu Jan 22, 2024
@cbhua (Member) commented Feb 26, 2024

Hi @Mu-Yanchen, thanks for raising this bug, and sorry for our late reply. Thanks also to @Haimrich for the help!

In the current version, we apply MatNet [1] to the ATSP. Unlike the other environments, the initial embedding for ATSP is located here.

We updated the MatNet implementation in b3f1446. You may want to check the minimal test in this notebook and play with it 🚀.

[1] Kwon, Yeong-Dae, et al. "Matrix encoding networks for neural combinatorial optimization." Advances in Neural Information Processing Systems 34 (2021): 5138-5149.
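
(For readers landing here, a minimal sketch of running MatNet on ATSP; the import path and constructor signature are assumptions based on the model zoo layout:)

import torch
from rl4co.envs import ATSPEnv
from rl4co.models.zoo.matnet import MatNet  # assumed import path

env = ATSPEnv(num_loc=30)
model = MatNet(env)  # assuming the usual zoo signature MatNet(env, **kwargs)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td_init = env.reset(batch_size=[3]).to(device)
out = model.to(device)(td_init.clone(), phase="test", decode_type="greedy", return_actions=True)
print(out["reward"].cpu())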

@cbhua cbhua self-assigned this Feb 26, 2024
fedebotu added a commit that referenced this issue Feb 26, 2024
@fedebotu (Member) commented Jun 8, 2024

Closing now. Feel free to reopen if any other issue arises! 👍

@fedebotu fedebotu closed this as completed Jun 8, 2024