ConfigSpace
An implementation to support complex conditions such as 'a < b and a * b < 10 and ...'
I have implemented a ConditionedConfigurationSpace that supports complex conditions between hyperparameters (e.g., x1 <= x2 and x1 * x2 < 100). Users can define a sample_condition function to restrict which configurations are generated.
The following functions are guaranteed to return valid configurations:
- self.sample_configuration()
- get_one_exchange_neighbourhood()  # may return an empty list
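Why can the neighbourhood be empty? Neighbours differ from the current configuration in exactly one hyperparameter, and any neighbour that fails the condition is dropped. The stdlib-only sketch below illustrates that filtering on small integer domains; the plain-dict configurations, the domains and the helper one_exchange_neighbours are assumptions for illustration only, not how ConfigSpace generates neighbours internally.

```python
from typing import Callable, Dict, List, Sequence

Config = Dict[str, int]


def one_exchange_neighbours(config: Config,
                            domains: Dict[str, Sequence[int]],
                            condition: Callable[[Config], bool]) -> List[Config]:
    """Change exactly one hyperparameter at a time; keep only neighbours that pass the condition."""
    neighbours: List[Config] = []
    for name, values in domains.items():
        for value in values:
            if value == config[name]:
                continue  # not a neighbour: nothing changed
            candidate = dict(config)
            candidate[name] = value
            if condition(candidate):
                neighbours.append(candidate)
    return neighbours


def sample_condition(config: Config) -> bool:
    return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100


start = {'x1': 9, 'x2': 11}  # valid: 9 <= 11 and 9 * 11 = 99 < 100
wide = one_exchange_neighbours(start, {'x1': range(1, 21), 'x2': range(1, 21)}, sample_condition)
# With these tight domains, every single-parameter change violates the condition,
# so the filtered neighbourhood is empty.
tight = one_exchange_neighbours(start, {'x1': [9, 10], 'x2': [11, 12]}, sample_condition)
```

The valid starting point stays valid; it simply has no valid one-exchange neighbours when the domains are tight, which is the situation the "may return an empty list" note warns about.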
Here is an example:
def sample_condition(config):
    # require x1 <= x2 and x1 * x2 < 100
    if config['x1'] > config['x2']:
        return False
    if config['x1'] * config['x2'] >= 100:
        return False
    return True
    # equivalent one-liner: return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100

cs = ConditionedConfigurationSpace()
cs.add_hyperparameters([...])
cs.set_sample_condition(sample_condition)  # set the sample condition after all hyperparameters are added
configs = cs.sample_configuration(1000)
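How often the condition accepts a draw determines how much work rejection sampling wastes. As a rough, self-contained check, assuming purely for illustration that x1 and x2 are integers in 1..20 (the real hyperparameters are elided in the example above), the snippet counts the accepted pairs and confirms that the if-based sample_condition and the commented one-liner agree:

```python
from itertools import product


def sample_condition(config):
    # require x1 <= x2 and x1 * x2 < 100
    if config['x1'] > config['x2']:
        return False
    if config['x1'] * config['x2'] >= 100:
        return False
    return True


def one_liner(config):
    return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100


# Hypothetical integer grid standing in for the elided hyperparameters.
grid = [{'x1': a, 'x2': b} for a, b in product(range(1, 21), repeat=2)]
accepted = [c for c in grid if sample_condition(c)]
agree = all(sample_condition(c) == one_liner(c) for c in grid)
acceptance_rate = len(accepted) / len(grid)
```

On this grid 116 of the 400 pairs are accepted (29%), so a uniform sampler needs roughly 3.4 draws per valid configuration on average.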
Implementing this feature with forbiddens such as ForbiddenClause or ForbiddenRelation might be a viable option, but it is somewhat complicated, and the user may need to pass the full configuration space to the forbidden object.
The implementation does not handle serialization of the sample_condition function.
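If serialization is needed, one possible workaround is to store only a name for the condition and re-attach the callable after loading, e.g. by calling set_sample_condition() on the restored space. The sketch assumes a hypothetical registry (CONDITIONS) and helpers (dump_condition, load_condition); none of these are part of ConfigSpace. It also shows why storing the callable itself is fragile: a lambda cannot be pickled.

```python
import json
import pickle
from typing import Callable, Dict


def sample_condition(config) -> bool:
    return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100


# Hypothetical registry mapping stable names to condition functions.
CONDITIONS: Dict[str, Callable[[dict], bool]] = {
    'x1_le_x2_and_product_lt_100': sample_condition,
}


def dump_condition(name: str) -> str:
    # Store only the registry key, never the function object itself.
    if name not in CONDITIONS:
        raise KeyError(name)
    return json.dumps({'sample_condition': name})


def load_condition(payload: str) -> Callable[[dict], bool]:
    # Look the callable up again after deserialization.
    return CONDITIONS[json.loads(payload)['sample_condition']]


def lambda_is_picklable() -> bool:
    # pickle stores functions by qualified name; a lambda cannot be looked up
    # that way, so pickling one fails.
    try:
        pickle.dumps(lambda config: True)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True
```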
Here is the implementation of ConditionedConfigurationSpace:
from typing import Callable, List, Optional, Union

import numpy as np
from ConfigSpace import ConfigurationSpace, Configuration
from ConfigSpace.exceptions import ForbiddenValueError


class ConditionedConfigurationSpace(ConfigurationSpace):
    """
    A configuration space that supports complex conditions between hyperparameters,
    e.g., x1 <= x2 and x1 * x2 < 100.

    The user can define a sample_condition function to restrict the generation of configurations.
    The following functions are guaranteed to return valid configurations:
    - self.sample_configuration()
    - get_one_exchange_neighbourhood()  # may return an empty list

    Example
    -------
    >>> def sample_condition(config):
    ...     # require x1 <= x2 and x1 * x2 < 100
    ...     if config['x1'] > config['x2']:
    ...         return False
    ...     if config['x1'] * config['x2'] >= 100:
    ...         return False
    ...     return True
    ...
    >>> cs = ConditionedConfigurationSpace()
    >>> cs.add_hyperparameters([...])
    >>> cs.set_sample_condition(sample_condition)  # set the sample condition after all hyperparameters are added
    >>> configs = cs.sample_configuration(1000)

    Author: Jhj
    """

    sample_condition: Optional[Callable[[Configuration], bool]] = None
    def set_sample_condition(self, sample_condition: Callable[[Configuration], bool]) -> None:
        """
        The sample_condition function takes a configuration as input and returns a boolean value:
        - If the return value is True, the configuration is valid and may be sampled.
        - If the return value is False, the configuration is invalid and will be rejected.

        This function must be called after all hyperparameters have been added to the conditioned space.
        It also re-checks the default configuration, so a ForbiddenValueError is raised
        if the default configuration does not satisfy the condition.
        """
        self.sample_condition = sample_condition
        self._check_default_configuration()
    def _check_forbidden(self, vector: np.ndarray) -> None:
        """
        This function is called in Configuration.is_valid_configuration().
        - When Configuration.__init__() is called with values (dict), is_valid_configuration() is called.
        - When Configuration.__init__() is called with a vector (np.ndarray), there is no validation check.
        This function is also called in get_one_exchange_neighbourhood().
        """
        # check the original forbidden clauses first
        super()._check_forbidden(vector)
        if self.sample_condition is not None:
            # Populating a configuration from an array does not check whether it is a legal configuration,
            # so _check_forbidden() is not called again. Otherwise, this would be stuck in an infinite loop.
            config = Configuration(self, vector=vector)
            if not self.sample_condition(config):
                raise ForbiddenValueError('User-defined sample condition is not satisfied.')
    def sample_configuration(self, size: int = 1) -> Union['Configuration', List['Configuration']]:
        """
        In ConfigurationSpace.sample_configuration, configurations are built from vectors (np.ndarray),
        so there is no validation check and _check_forbidden() is not called.
        We therefore need to check the sample condition manually.

        Returns a single configuration if size == 1, else a list of configurations.
        """
        if self.sample_condition is None:
            return super().sample_configuration(size=size)

        if not isinstance(size, int):
            raise TypeError('Argument size must be of type int, but is %s'
                            % type(size))
        elif size < 1:
            return []

        error_iteration = 0
        accepted_configurations = []  # type: List['Configuration']
        while len(accepted_configurations) < size:
            missing = size - len(accepted_configurations)
            if missing != size:
                # After the first round, oversample by 10% to compensate for rejections.
                missing = int(1.1 * missing)
            missing += 2
            configurations = super().sample_configuration(size=missing)  # missing > 1, so a list is returned
            # Keep only the configurations that satisfy the user-defined condition.
            configurations = [c for c in configurations if self.sample_condition(c)]
            if len(configurations) > 0:
                accepted_configurations.extend(configurations)
            else:
                # Give up after too many rounds in which nothing was accepted.
                error_iteration += 1
                if error_iteration > 1000:
                    raise ForbiddenValueError("Cannot sample valid configuration for %s" % self)

        if size <= 1:
            return accepted_configurations[0]
        else:
            return accepted_configurations[:size]
    def add_hyperparameter(self, *args, **kwargs):
        if self.sample_condition is not None:
            raise ValueError('Please add hyperparameter before setting sample condition.')
        return super().add_hyperparameter(*args, **kwargs)

    def add_hyperparameters(self, *args, **kwargs):
        if self.sample_condition is not None:
            raise ValueError('Please add hyperparameters before setting sample condition.')
        return super().add_hyperparameters(*args, **kwargs)
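The batching loop in sample_configuration() can be studied in isolation: the first round requests the number of missing configurations plus 2, later rounds request int(1.1 * missing) + 2, each batch is filtered by the condition, and sampling gives up after 1000 rounds that accept nothing. The stdlib-only sketch below mirrors that loop under stated assumptions: a hypothetical draw_batch(n) callable replaces ConfigurationSpace.sample_configuration(), plain dicts replace Configuration objects, it raises ValueError in place of ForbiddenValueError, and it always returns a list (the class returns a single configuration when size == 1).

```python
import random
from typing import Callable, Dict, List

Config = Dict[str, float]


def sample_with_condition(size: int,
                          draw_batch: Callable[[int], List[Config]],
                          condition: Callable[[Config], bool],
                          max_empty_rounds: int = 1000) -> List[Config]:
    """Batched rejection sampling mirroring the loop in sample_configuration()."""
    if size < 1:
        return []
    accepted: List[Config] = []
    empty_rounds = 0
    while len(accepted) < size:
        missing = size - len(accepted)
        if missing != size:
            # After the first round, oversample by 10% to offset rejections.
            missing = int(1.1 * missing)
        missing += 2
        batch = [c for c in draw_batch(missing) if condition(c)]
        if batch:
            accepted.extend(batch)
        else:
            empty_rounds += 1
            if empty_rounds > max_empty_rounds:
                raise ValueError('cannot sample a valid configuration')
    # Surplus samples accepted in the last round are discarded.
    return accepted[:size]


def make_draw_batch(seed: int) -> Callable[[int], List[Config]]:
    # Hypothetical unconstrained sampler: x1 and x2 uniform in [0, 20].
    rng = random.Random(seed)
    return lambda n: [{'x1': rng.uniform(0, 20), 'x2': rng.uniform(0, 20)}
                      for _ in range(n)]


def sample_condition(config: Config) -> bool:
    return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100


configs = sample_with_condition(100, make_draw_batch(0), sample_condition)
```

Counting only rounds that accept nothing, rather than total draws, means a restrictive but satisfiable condition still terminates, while an unsatisfiable one fails after a bounded number of rounds.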