rl4co
Convert TSP to ATSP
Describe the bug
I used the notebook in the link below to learn about rl4co(https://github.com/ai4co/rl4co/blob/main/notebooks/tutorials/2-creating-new-env-model.ipynb). I now want to verify the ATSP method, so I import ATSPEnv instead of TSPEnv like this:
```python
from rl4co.envs import ATSPEnv

batch_size = 2
env_atsp = ATSPEnv(num_loc=30)
reward, td, actions = rollout(env_atsp, env_atsp.reset(batch_size=[batch_size]), random_policy)
env_atsp.render(td, actions)
```
This runs correctly, but when I roll out an untrained model as below, I get the following error:
```python
# Greedy rollouts over untrained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td_init_atsp = env_atsp.reset(batch_size=[3]).to(device)
model_atsp = model_atsp.to(device)
out_atsp = model_atsp(td_init_atsp.clone(), phase="test", decode_type="greedy", return_actions=True)
actions_untrained = out_atsp['actions'].cpu().detach()
rewards_untrained = out_atsp['reward'].cpu().detach()
for i in range(3):
    print(f"Problem {i+1} | Cost: {-rewards_untrained[i]:.3f}")
    env_atsp.render(td_init_atsp[i], actions_untrained[i])
```
The error is:
```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[5], line 5
      3 td_init_atsp = env_atsp.reset(batch_size=[3]).to(device)
      4 model_atsp = model_atsp.to(device)
----> 5 out_atsp = model_atsp(td_init_atsp.clone(), phase="test", decode_type="greedy", return_actions=True)
      6 actions_untrained = out_atsp['actions'].cpu().detach()
      7 rewards_untrained = out_atsp['reward'].cpu().detach()
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None
File ~/miniconda3/lib/python3.10/site-packages/rl4co/models/rl/common/base.py:246, in RL4COLitModule.forward(self, td, **kwargs)
    244     log.info("Using env from kwargs")
    245     env = kwargs.pop("env")
--> 246 return self.policy(td, env, **kwargs)
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None
File ~/miniconda3/lib/python3.10/site-packages/rl4co/models/zoo/common/autoregressive/policy.py:140, in AutoregressivePolicy.forward(self, td, env, phase, return_actions, return_entropy, return_init_embeds, **decoder_kwargs)
    125 """Forward pass of the policy.
    126
    127 Args:
   (...)
    136     out: Dictionary containing the reward, log likelihood, and optionally the actions and entropy
    137 """
    139 # ENCODER: get embeddings from initial state
--> 140 embeddings, init_embeds = self.encoder(td)
    142 # Instantiate environment if needed
    143 if isinstance(env, str) or env is None:
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None
File ~/miniconda3/lib/python3.10/site-packages/rl4co/models/zoo/common/autoregressive/encoder.py:74, in GraphAttentionEncoder.forward(self, td, mask)
     62 """Forward pass of the encoder.
     63 Transform the input TensorDict into a latent representation.
     64
   (...)
     71     init_h: Initial embedding of the input
     72 """
     73 # Transfer to embedding space
---> 74 init_h = self.init_embedding(td)
     76 # Process embedding
     77 h = self.net(init_h, mask)
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None
File ~/miniconda3/lib/python3.10/site-packages/rl4co/models/nn/env_embeddings/init.py:49, in TSPInitEmbedding.forward(self, td)
     48 def forward(self, td):
---> 49     out = self.init_embed(td["locs"])
     50     return out
File ~/miniconda3/lib/python3.10/site-packages/tensordict/tensordict.py:3697, in TensorDictBase.__getitem__(self, index)
   3695 idx_unravel = _unravel_key_to_tuple(index)
   3696 if idx_unravel:
-> 3697     return self._get_tuple(idx_unravel, NO_DEFAULT)
   3698 if (istuple and not index) or (not istuple and index is Ellipsis):
   3699     # empty tuple returns self
   3700     return self
File ~/miniconda3/lib/python3.10/site-packages/tensordict/tensordict.py:4625, in TensorDict._get_tuple(self, key, default)
   4624 def _get_tuple(self, key, default):
-> 4625     first = self._get_str(key[0], default)
   4626     if len(key) == 1 or first is default:
   4627         return first
File ~/miniconda3/lib/python3.10/site-packages/tensordict/tensordict.py:4621, in TensorDict._get_str(self, key, default)
   4619 out = self._tensordict.get(first_key, None)
   4620 if out is None:
-> 4621     return self._default_get(first_key, default)
   4622 return out
File ~/miniconda3/lib/python3.10/site-packages/tensordict/tensordict.py:1455, in TensorDictBase._default_get(self, key, default)
   1452     return default
   1453 else:
   1454     # raise KeyError
-> 1455     raise KeyError(
   1456         TensorDictBase.KEY_ERROR.format(
   1457             key, self.__class__.__name__, sorted(self.keys())
   1458         )
   1459     )
KeyError: 'key "locs" not found in TensorDict with keys [\'action_mask\', \'cost_matrix\', \'current_node\', \'done\', \'first_node\', \'i\', \'terminated\']'
```
Reason and Possible fixes
I think the problem is a mismatch between the model and ATSPEnv, but I have not found a solution. Thank you for your time and attention.
By the way, how should I train an ATSP model the way the tutorial trains a TSP model?
I think the problem comes from https://github.com/ai4co/rl4co/blob/1a2da37d6104c33646f74bb4b040d2a4006876c2/rl4co/models/nn/env_embeddings/init.py#L16. Basically, by default the same InitEmbedding used for TSP is also used for the ATSP environment. The issue is that in TSP you can simply embed the coordinates of each node ('locs' in the TSP environment) and let the encoder infer the Euclidean distances, whereas in ATSP I think you can't: all you have is an asymmetric distance matrix ('cost_matrix' in the ATSP environment), and node coordinates would not help the encoder understand why going from one node to another has one cost while going back has a different one.
So I think that in order to solve the ATSP problem with the AM model, you need a custom InitEmbedding that encodes the nodes in a way that also carries information about the asymmetric distance matrix, maybe a GNN or something similar.
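As a minimal sketch of that idea (not rl4co's actual implementation; the class name and the idea of projecting cost-matrix rows are my own illustration), one could replace the coordinate-based embedding with a module that reads 'cost_matrix' instead of 'locs', e.g. by linearly projecting each node's row of outgoing costs:

```python
import torch
import torch.nn as nn


class ATSPInitEmbedding(nn.Module):
    """Hypothetical init embedding for ATSP: instead of embedding node
    coordinates (td["locs"], which ATSPEnv does not provide), project each
    node's row of the asymmetric cost matrix so its embedding carries
    information about outgoing travel costs."""

    def __init__(self, num_loc: int, embed_dim: int = 128):
        super().__init__()
        # Each node's feature vector is its row of outgoing costs (length num_loc)
        self.project = nn.Linear(num_loc, embed_dim)

    def forward(self, td):
        # td["cost_matrix"] has shape [batch, num_loc, num_loc];
        # row i holds the costs of traveling from node i to every other node
        return self.project(td["cost_matrix"])


# Usage with a stand-in for the TensorDict: embeddings come out as
# [batch, num_loc, embed_dim], the shape the encoder expects
emb = ATSPInitEmbedding(num_loc=30, embed_dim=128)({"cost_matrix": torch.rand(2, 30, 30)})
```

A plain row projection only captures outgoing costs; concatenating the matrix column (incoming costs) or using a message-passing GNN over the cost matrix, as suggested above, would give the encoder a fuller picture of the asymmetry.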
Hi @Mu-Yanchen, thanks for raising this bug and sorry for our late reply. Also thanks to @Haimrich's help!
In the current version, we apply MatNet [1] to the ATSP. Unlike other environments, the initial embedding for ATSP is located here.
We updated the MatNet implementation in b3f1446820fd6c2d9ac3399369ffc134dd86b3ab. You may want to check the minimal test in this notebook and play with it 🚀.
[1] Kwon, Yeong-Dae, et al. "Matrix encoding networks for neural combinatorial optimization." Advances in Neural Information Processing Systems 34 (2021): 5138-5149.
Closing now. Feel free to reopen if any other issue arises! :+1: