Major environment refactoring (draft version)
> [!IMPORTANT]
> The merge of this pull request is postponed because it contains sensitive modifications to the environment logic, which may introduce hidden bugs, so we should update them carefully. This full version of the environment refactoring will therefore be kept as a draft. We opened a base-version refactoring pull request, https://github.com/ai4co/rl4co/pull/169, which only touches the environment structure and adds the generator without changing any logic, making it a safe refactor of the current state. In the future, we will build on this draft's full version and refactor the environments further, step by step.
## Description
Together with the major modeling refactoring (#165), this PR carries out a major, long-due refactoring of the RL4CO environments codebase.
## Motivation and Context
This refactoring is driven by the following motivations:
- New Feature Integration: We aim to support a data generator capable of producing various distributions for initialized instances.
- Standardization of Environments: Our environments were developed at different times, which led to inconsistencies in content, logic, and formatting. This refactoring standardizes them.
- Code Cleanup: Earlier versions included redundant code, functions, and calculation logic. This refactoring cleans up these elements, enhancing the codebase's readability and maintainability.
## Changelog
### Environment Structure Refactoring
The refactored structure for environments is as follows:
```
rl4co/
├── models/
└── envs/
    ├── eda/
    ├── scheduling/
    └── routing/
        ├── tsp/
        │   ├── env.py
        │   ├── generator.py
        │   └── render.py
        ├── cvrp/
        │   ├── env.py
        │   ├── generator.py
        │   └── render.py
        └── ...
```
We have restructured the organization of the environment files for improved modularity and clarity. Each environment has its own directory, comprising three components:
- `env.py`: The core framework of the environment, managing functions such as `_reset()`, `_step()`, and others. For a comprehensive understanding, please refer to the documentation.
- `generator.py`: Replaces the previous `generate_data()` function; this module randomly initializes instances within the environment. The updated version now supports custom data distributions. See the following sections for more details.
- `render.py`: Visualizes solutions. Its separation from the main environment file enhances overall code readability.
### Data Generator Support
Each environment generator is based on the base `Generator` class with the following functions:
```python
from tensordict import TensorDict


class Generator:
    def __init__(self, **kwargs):
        # Record all instance initialization parameters (num_loc, min_loc, etc.)
        self.kwargs = kwargs

    def __call__(self, batch_size) -> TensorDict:
        # Regularize batch_size into the list format expected by TorchRL
        batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
        return self._generate(batch_size)

    def _generate(self, batch_size, **kwargs) -> TensorDict:
        # To be implemented by each environment's concrete generator
        raise NotImplementedError
```
- `__init__()` records all the environment instance initialization parameters, for example, `num_loc`, `min_loc`, `max_loc`, etc. Thus, the `__init__()` function of the environment (e.g., `CVRPEnv.__init__(...)`) only takes `generator` and `generator_params` as input. The environment initialization now looks like this:

  ```python
  env = CVRPEnv(generator_params={"num_loc": 20})
  # Another way
  generator = CVRPGenerator(num_loc=20)
  env = CVRPEnv(generator)
  ```

  Various samplers are also initialized here. We provide the `get_sampler()` function, which returns a `torch.distributions` class based on the input variables. By default, we support the `Uniform`, `Normal`, `Exponential`, and `Poisson` distributions for locations, and `center` and `corner` for depots. You can also pass your own distribution sampler. See the following sections for more details.
- `__call__()` is a middle wrapper; at the moment, it is used to regularize the `batch_size` format supported by TorchRL (i.e., a `list`). Note that in this refactored version, we fix `batch_size` to a single dimension for easier implementation and clearer understanding, since even multiple batch-size dimensions can easily be flattened into a single dimension.
- `_generate()` is the part you implement for your own environment data generator; see the sketch below.
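For concreteness, a minimal subclass might look like the following sketch. The `TSPGenerator` name, its defaults, and the direct use of `torch.distributions.Uniform` (instead of `get_sampler()`) are illustrative assumptions, not the exact code in this PR:

```python
import torch
from tensordict import TensorDict


class TSPGenerator(Generator):
    """Illustrative generator: uniform 2D coordinates for TSP instances."""

    def __init__(self, num_loc: int = 20, min_loc: float = 0.0, max_loc: float = 1.0, **kwargs):
        super().__init__(num_loc=num_loc, min_loc=min_loc, max_loc=max_loc, **kwargs)
        self.num_loc = num_loc
        # The real generators would build this sampler via get_sampler()
        self.loc_sampler = torch.distributions.Uniform(min_loc, max_loc)

    def _generate(self, batch_size) -> TensorDict:
        # Sample [*batch_size, num_loc, 2] coordinates and wrap them in a TensorDict
        locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2))
        return TensorDict({"locs": locs}, batch_size=batch_size)


generator = TSPGenerator(num_loc=50)
td = generator(batch_size=64)  # td["locs"].shape == torch.Size([64, 50, 2])
```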
### New `get_sampler()` function
This implementation mainly refers to @ngastzepeda's code. In the current version, we support the following distributions:
- `center`: For depots. All depots will be initialized in the center of the space.
- `corner`: For depots. All depots will be initialized in the bottom-left corner of the space.
- `Uniform`: Takes `min_val` and `max_val` as input.
- `Exponential` and `Poisson`: Take `mean_val` and `std_val` as input.
You can also use your own `Callable` function as the sampler. This function takes `batch_size: List[int]` as input and returns the sampled `torch.Tensor`, for example:
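Here is a hypothetical custom sampler following that contract; the ring-shaped distribution and the bound `num_loc` default are our own assumptions for the example:

```python
from typing import List

import torch


def ring_sampler(batch_size: List[int], num_loc: int = 20) -> torch.Tensor:
    """Hypothetical sampler: node coordinates on a noisy ring around (0.5, 0.5)."""
    angles = torch.rand(*batch_size, num_loc) * 2 * torch.pi
    radii = 0.35 + 0.05 * torch.randn(*batch_size, num_loc)
    locs = torch.stack(
        [0.5 + radii * torch.cos(angles), 0.5 + radii * torch.sin(angles)], dim=-1
    )
    return locs.clamp(0.0, 1.0)  # shape: [*batch_size, num_loc, 2]


locs = ring_sampler([64])  # torch.Size([64, 20, 2])
```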
### Modifications to `RL4COEnvBase`
We moved the checks for `batch_size` and `device` from every environment to the base class for clarity, as shown in
https://github.com/ai4co/rl4co/blob/b70566bc2354ade45d249a8eb86c40f0e2b47230/rl4co/envs/common/base.py#L130-L138
We added a new `_get_reward()` function alongside the original `get_reward()` function and moved `check_solution_validity()` from every environment to the base class for clarity, as shown in
https://github.com/ai4co/rl4co/blob/b70566bc2354ade45d249a8eb86c40f0e2b47230/rl4co/envs/common/base.py#L175-L187
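Conceptually, the wiring is a few lines; the following is a sketch of the pattern rather than the verbatim base-class code, and the `check_solution` flag name is an assumption:

```python
def get_reward(self, td, actions):
    # Public entry point shared by all environments: optionally validate the
    # solution once here, then delegate to the env-specific _get_reward().
    if self.check_solution:
        self.check_solution_validity(td, actions)
    return self._get_reward(td, actions)
```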
### Standardization
We standardized the contents of `env.py` with the following functions:
```python
class EnvName(RL4COEnvBase):
    name = "env_name"

    def __init__(self, generator: EnvGenerator, generator_params: dict): pass

    def _step(self, td: TensorDict) -> TensorDict: pass

    @staticmethod
    def get_action_mask(td: TensorDict) -> torch.Tensor: pass

    def _reset(self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None) -> TensorDict: pass

    def _get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor: pass

    @staticmethod
    def check_solution_validity(td: TensorDict, actions: torch.Tensor) -> None: pass

    @staticmethod
    def render(td: TensorDict, actions: torch.Tensor = None, ax=None): pass

    def _make_spec(self, generator: EnvGenerator): pass
```
This order is considered natural and easy to follow, and we expect all environments to follow it for easier reference and maintenance. In more detail, we applied the following standardizations:
- We changed the variable name `available` to `visited` for a more intuitive understanding. In the `_step()` and `get_action_mask()` calculations, `visited` records which nodes have been visited, and the `action_mask` is derived from it together with the environment constraints (e.g., capacity, time window, etc.). Separating these two variables makes the calculation logic clearer.
- For some environments, we changed the `_step()` function to a non-static method to follow the TorchRL style.
- We standardized the `get_action_mask()` calculation logic, which generally contains three parts: (a) initialize the `action_mask` based on `visited`; (b) update the cities' `action_mask` based on the state; (c) update the depot's `action_mask` last. In our experience, this order causes fewer conflicts and less mess; see the sketch after this list.
- All 1-D features, e.g., `i`, `capacity`, `used_capacity`, etc., are initialized with the size `[*batch_size, 1]` instead of `[*batch_size]`. The reason is that many masking operations involve logical calculations between these 1-D features and 2-D features, e.g., capacity with demand. This also stays consistent with the TorchRL implementation.
- We rewrote the environment docstrings with descriptions of observations, constraints, finish conditions, rewards, and args so that users can better understand each environment. We also moved data-related parameters (e.g., `num_loc`, `min_loc`, `max_loc`) to the generator for clarity.
- We added a `cost` variable to the `get_reward()` function for an intuitive understanding. In this case, the return (reward) is `-cost`.
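To make the three-part `get_action_mask()` order concrete, here is a CVRP-flavored sketch; the tensor names and shapes are illustrative assumptions (`visited` and `demand` of shape `[*batch_size, num_loc + 1]` including the depot; `used_capacity`, `capacity`, and `current_node` of shape `[*batch_size, 1]`):

```python
import torch
from tensordict import TensorDict


def get_action_mask(td: TensorDict) -> torch.Tensor:
    # (a) initialize from `visited`: only unvisited nodes are candidates
    action_mask = ~td["visited"].bool()

    # (b) update the cities' mask from the state, e.g., capacity feasibility
    exceeds_cap = td["demand"] + td["used_capacity"] > td["capacity"]
    action_mask[..., 1:] = action_mask[..., 1:] & ~exceeds_cap[..., 1:]

    # (c) update the depot's mask last, e.g., forbid revisiting it immediately
    action_mask[..., 0] = (td["current_node"] != 0).squeeze(-1)
    return action_mask
```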
### Other Fixes
- In CVRP, we renamed the variables `vehicle_capacity` → `capacity` and `capacity` → `unnorm_capacity` for clarity.
- [⚠️ Sensitive Change] The `demand` variable now also covers the depot. For example, in the previous `CVRPEnv()`, given `num_loc=50`, `td["locs"]` had the size `[batch_size, 51, 2]` (with the depot), while `td["demand"]` had the size `[batch_size, 50]`. This caused index shifting in the `get_action_mask()` function, which required a few padding operations.
- Fixed the SDVRP environment action mask calculation bug.
- Added a numerical error bound (`0` → `1e-5`) for better robustness against edge cases; for example, in SDVRP, `done = ~(demand > 0).any(-1)` → `done = ~(demand > 1e-5).any(-1)`.
- In the CVRP, OP, and PCTSP environments, when getting variables from tables keyed by `num_loc` (e.g., CVRP `CAPACITIES`), if the given `num_loc` is not in the table, we now use the closest `num_loc` as a replacement and raise a warning, improving runtime robustness; see the sketch below.
- Fixed the return type of `get_reward()`.
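A sketch of the closest-`num_loc` fallback; the table values here are just the classic CVRP defaults, and the helper name is our own:

```python
import warnings

# Illustrative capacity table keyed by instance size (num_loc)
CAPACITIES = {10: 20.0, 20: 30.0, 50: 40.0, 100: 50.0}


def get_capacity(num_loc: int) -> float:
    """Return the tabulated capacity, falling back to the closest known size."""
    if num_loc in CAPACITIES:
        return CAPACITIES[num_loc]
    closest = min(CAPACITIES, key=lambda k: abs(k - num_loc))
    warnings.warn(
        f"num_loc={num_loc} not found in the capacity table; "
        f"using the closest num_loc={closest} instead."
    )
    return CAPACITIES[closest]
```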
## Notes
- In the current version, we don't support distributions over integer values, e.g., `num_depot`, `num_agents`. These values are initialized by `torch.randint()`.
- In the reward calculation, for environments with the constraint of starting and ending at the depot, the actions should be padded with `0` at the start and end; see the sketch below.
- In the current version, only routing environments have been refactored. We will also refactor the EDA and scheduling environments soon.
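For example, padding the depot index onto both ends of an action sequence might look like this (a minimal sketch with made-up actions):

```python
import torch
import torch.nn.functional as F

# actions: [batch_size, seq_len] node indices that exclude the depot (index 0)
actions = torch.tensor([[3, 1, 2], [2, 3, 1]])

# Pad a 0 (the depot) at the start and end before computing the tour length
padded = F.pad(actions, (1, 1), mode="constant", value=0)
# tensor([[0, 3, 1, 2, 0],
#         [0, 2, 3, 1, 0]])
```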
Here is a summary of the refactoring status for each environment:
- Decompose: decompose the environment into a folder with `env.py`, `generator.py`, and `render.py`; fix the `__init__()` and `_reset()` functions;
- Training Checking: check the training of the refactored environment;
- Documentation: clean up and fix the environment documents and logic comments;
- Solution Validity: check whether the environment contains a `check_solution_validity()` function;
- Clean up Logic: check whether the `_step()` and `get_action_mask()` functions are cleaned up with the standard pipeline.
| Environment | Decompose | Training Checking | Documentation | Solution Validity | Clean up Logic |
|---|---|---|---|---|---|
| TSP | ✅ | ✅ | ✅ | ✅ | ✅ |
| CVRP | ✅ | ✅ | ✅ | ✅ | ✅ |
| CVRPTW | ✅ | ✅ | ✅ | ✅ | |
| PCTSP | ✅ | ✅ | ✅ | ✅ | |
| OP | ✅ | ✅ | ✅ | | |
| SDVRP | ✅ | ✅ | ✅ | ✅ | ✅ |
| SVRP | ✅ | ✅ | ✅ | | |
| ATSP | ✅ | ✅ | ✅ | ✅ | ✅ |
| MTSP | ✅ | ✅ | ✅ | | |
| SPCTSP | ✅ | ✅ | ✅ | ✅ | |
| PDP | ✅ | ✅ | ✅ | ✅ | ✅ |
| MPDP | ✅ | ✅ | ✅ | | |
| MDCPDP | ✅ | ✅ | | | |
## Types of changes
What types of changes does your code introduce? Remove all that do not apply:
- [x] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds core functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to change)
- [x] Documentation (update in the documentation)
- [ ] Example (update in the folder of examples)
## Checklist
- [x] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (required for a bug fix or a new feature).
- [ ] I have updated the documentation accordingly.
## Thanks, and we need your help
Thanks to @ngastzepeda for the base code for this refactoring!
If you have time, you are welcome to share your ideas/feedback on this PR. CC: @Furffico @henry-yeh @bokveizen @LTluttmann
There is quite a bit of work remaining for this PR, and I will actively update it here.
Let's also remember to fix the shifts in the `torch.roll` distance calculation that @ngastzepeda noticed, e.g. here. These do not affect calculations in Euclidean problems, but it's best to have them conceptually correct.
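For reference, a roll-based tour length looks like the sketch below; rolling by `1` pairs each node with its predecessor and rolling by `-1` with its successor, and for symmetric (Euclidean) costs both shifts give the same total, which is why only the conceptual direction is at stake:

```python
import torch


def tour_length(locs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Euclidean tour length; locs: [batch, n, 2], actions: [batch, n]."""
    ordered = locs.gather(1, actions.unsqueeze(-1).expand(-1, -1, 2))
    # Difference between each node and the previous one (wrapping around),
    # so the closing edge back to the start is included automatically.
    return (ordered - ordered.roll(1, dims=1)).norm(p=2, dim=-1).sum(-1)
```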
Note that we moved most of the above into #169 (without modifications to the environment logic or variables)! We will address the comments and merge soon~
There have been too many changes to track recently, and it seems that several features have already been added.
I will close this for now and come back to it with a fresh PR if needed!