
An implementation to support complex conditions such as 'a < b and a * b < 10 and ...'

Open jhj0411jhj opened this issue 1 year ago • 3 comments

I implemented a ConditionedConfigurationSpace that supports complex conditions between hyperparameters (e.g., x1 <= x2 and x1 * x2 < 100). Users can define a sample_condition function to restrict which configurations are generated.

The following functions are guaranteed to return valid configurations:

  • self.sample_configuration()
  • get_one_exchange_neighbourhood() # may return an empty list

Here is an example:

def sample_condition(config):
    # require x1 <= x2 and x1 * x2 < 100
    if config['x1'] > config['x2']:
        return False
    if config['x1'] * config['x2'] >= 100:
        return False
    return True
    # Equivalent one-liner:
    # return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100

cs = ConditionedConfigurationSpace()
cs.add_hyperparameters([...])
cs.set_sample_condition(sample_condition)  # set the sample condition after all hyperparameters are added
configs = cs.sample_configuration(1000)
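
Conceptually, sampling under a sample condition is rejection sampling: draw candidates from the unconstrained space and keep only those that pass the check. Below is a minimal standalone sketch of that idea. It uses plain dicts and the standard library only, not ConfigSpace, and `sample_uniform`, `sample_with_condition`, `max_failed_batches`, and the integer bounds are hypothetical names and values chosen for illustration.

```python
import random
from typing import Callable, Dict, List, Tuple

Config = Dict[str, int]
Bounds = Dict[str, Tuple[int, int]]


def sample_condition(config: Config) -> bool:
    # Same constraint as above: x1 <= x2 and x1 * x2 < 100.
    return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100


def sample_uniform(rng: random.Random, bounds: Bounds) -> Config:
    # Draw every hyperparameter independently, ignoring the constraint.
    return {name: rng.randint(lo, hi) for name, (lo, hi) in bounds.items()}


def sample_with_condition(
    bounds: Bounds,
    condition: Callable[[Config], bool],
    size: int,
    rng: random.Random,
    max_failed_batches: int = 1000,  # assumed cap, mirrors the retry limit idea
) -> List[Config]:
    accepted: List[Config] = []
    failed_batches = 0
    while len(accepted) < size:
        missing = size - len(accepted)
        # Draw a batch of candidates and keep only the valid ones.
        batch = [sample_uniform(rng, bounds) for _ in range(missing + 2)]
        valid = [c for c in batch if condition(c)]
        if valid:
            accepted.extend(valid)
        else:
            failed_batches += 1
            if failed_batches > max_failed_batches:
                raise ValueError('Cannot sample a valid configuration.')
    return accepted[:size]


bounds = {'x1': (1, 50), 'x2': (1, 50)}
configs = sample_with_condition(bounds, sample_condition, 1000, random.Random(0))
```

The cost of this approach grows as the valid region shrinks: if only a small fraction of the unconstrained space satisfies the condition, most candidates are thrown away.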

Implementing this feature with forbidden clauses such as ForbiddenClause or ForbiddenRelation might be a viable option, but it is more complicated, and the user may need to pass the full config space to the forbidden object.

The implementation does not handle serialization of the sample_condition function.
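
One possible way around the serialization gap, offered here as an assumption and not as anything ConfigSpace provides, is to register each condition under a name and serialize only that name. In this sketch, `CONDITION_REGISTRY`, `register_condition`, `dump_condition`, and `load_condition` are hypothetical helpers, and configurations are plain dicts.

```python
import json
from typing import Callable, Dict

# Hypothetical registry mapping a stable name to a condition function.
CONDITION_REGISTRY: Dict[str, Callable[[dict], bool]] = {}


def register_condition(name: str):
    # Decorator that stores the function under the given name.
    def decorator(func: Callable[[dict], bool]) -> Callable[[dict], bool]:
        CONDITION_REGISTRY[name] = func
        return func
    return decorator


@register_condition('x1_le_x2_and_product_lt_100')
def sample_condition(config: dict) -> bool:
    return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100


def dump_condition(func: Callable[[dict], bool]) -> str:
    # Serialize only the registered name, never the function body.
    for name, registered in CONDITION_REGISTRY.items():
        if registered is func:
            return json.dumps({'sample_condition': name})
    raise KeyError('Condition is not registered.')


def load_condition(payload: str) -> Callable[[dict], bool]:
    # Look the function up again by name when deserializing.
    name = json.loads(payload)['sample_condition']
    return CONDITION_REGISTRY[name]


payload = dump_condition(sample_condition)
restored = load_condition(payload)
```

This only works when the process that loads the payload has imported the module that registers the condition, so it trades generality for safety compared with pickling arbitrary functions.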

Here is the implementation of ConditionedConfigurationSpace:

from typing import List, Union, Callable
import numpy as np
from ConfigSpace import ConfigurationSpace, Configuration
from ConfigSpace.exceptions import ForbiddenValueError


class ConditionedConfigurationSpace(ConfigurationSpace):
    """
    A configuration space that supports complex conditions between hyperparameters,
        e.g., x1 <= x2 and x1 * x2 < 100.

    Users can define a sample_condition function to restrict the generation of configurations.

    The following functions are guaranteed to return valid configurations:
        - self.sample_configuration()
        - get_one_exchange_neighbourhood()  # may return an empty list

    Example
    -------

    >>> def sample_condition(config):
    ...     # require x1 <= x2 and x1 * x2 < 100
    ...     if config['x1'] > config['x2']:
    ...         return False
    ...     if config['x1'] * config['x2'] >= 100:
    ...         return False
    ...     return True
    ...
    >>> cs = ConditionedConfigurationSpace()
    >>> cs.add_hyperparameters([...])
    >>> cs.set_sample_condition(sample_condition)  # set the sample condition after all hyperparameters are added
    >>> configs = cs.sample_configuration(1000)

    Author: Jhj

    """
    # None means that no user-defined condition has been set yet.
    sample_condition: Union[Callable[[Configuration], bool], None] = None

    def set_sample_condition(self, sample_condition: Callable[[Configuration], bool]):
        """
        The sample_condition function takes a configuration as input and returns a boolean value.
            - If the return value is True, the configuration is valid and will be sampled.
            - If the return value is False, the configuration is invalid and will be rejected.
        This function should be called after all hyperparameters are added to the conditioned space.
        """
        self.sample_condition = sample_condition
        self._check_default_configuration()

    def _check_forbidden(self, vector: np.ndarray) -> None:
        """
        This function is called in Configuration.is_valid_configuration().
            - When Configuration.__init__() is called with values (dict), is_valid_configuration() is called.
            - When Configuration.__init__() is called with vectors (np.ndarray), there will be no validation check.
        This function is also called in get_one_exchange_neighbourhood().
        """
        # check original forbidden clauses first
        super()._check_forbidden(vector)

        if self.sample_condition is not None:
            # Building a Configuration from a vector skips validation, so
            # _check_forbidden() is not re-entered here. Otherwise this call
            # would recurse infinitely.
            config = Configuration(self, vector=vector)
            if not self.sample_condition(config):
                raise ForbiddenValueError('User-defined sample condition is not satisfied.')

    def sample_configuration(self, size: int = 1) -> Union['Configuration', List['Configuration']]:
        """
        In ConfigurationSpace.sample_configuration, configurations are built with vectors (np.ndarray),
            so there will be no validation check and _check_forbidden() will not be called.
            We need to check the sample condition manually.

        Returns a single configuration if size = 1 else a list of Configurations
        """
        if self.sample_condition is None:
            return super().sample_configuration(size=size)

        if not isinstance(size, int):
            raise TypeError('Argument size must be of type int, but is %s'
                            % type(size))
        elif size < 1:
            return []

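        error_iteration = 0  # number of batches that yielded no valid configuration
        accepted_configurations: List[Configuration] = []
        while len(accepted_configurations) < size:
            missing = size - len(accepted_configurations)

            # After the first batch, oversample by 10% to compensate for rejections.
            # The +2 also ensures missing > 1, so super() returns a list.
            if missing != size:
                missing = int(1.1 * missing)
            missing += 2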

            configurations = super().sample_configuration(size=missing)  # missing > 1, return a list
            configurations = [c for c in configurations if self.sample_condition(c)]
            if len(configurations) > 0:
                accepted_configurations.extend(configurations)
            else:
                error_iteration += 1
                if error_iteration > 1000:
                    raise ForbiddenValueError("Cannot sample valid configuration for %s" % self)

        if size == 1:  # size < 1 has already returned above
            return accepted_configurations[0]
        else:
            return accepted_configurations[:size]

    def add_hyperparameter(self, *args, **kwargs):
        if self.sample_condition is not None:
            raise ValueError('Please add hyperparameter before setting sample condition.')
        return super().add_hyperparameter(*args, **kwargs)

    def add_hyperparameters(self, *args, **kwargs):
        if self.sample_condition is not None:
            raise ValueError('Please add hyperparameters before setting sample condition.')
        return super().add_hyperparameters(*args, **kwargs)
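
To illustrate why get_one_exchange_neighbourhood() may return an empty list once a sample condition is set, here is a toy sketch that does not use ConfigSpace. It changes one integer hyperparameter at a time by one step and drops every neighbour that violates the condition. `one_exchange_neighbours`, `equality_condition`, and the bounds are hypothetical and exist only for this example.

```python
from typing import Callable, Dict, List, Tuple

Config = Dict[str, int]
Bounds = Dict[str, Tuple[int, int]]


def one_exchange_neighbours(
    config: Config, bounds: Bounds, condition: Callable[[Config], bool]
) -> List[Config]:
    neighbours: List[Config] = []
    for name, (lower, upper) in bounds.items():
        for step in (-1, 1):
            value = config[name] + step
            if not lower <= value <= upper:
                continue  # stay inside the hyperparameter's range
            # Copy the configuration with exactly one value exchanged.
            candidate = dict(config, **{name: value})
            if condition(candidate):
                neighbours.append(candidate)
    return neighbours


def sample_condition(config: Config) -> bool:
    return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100


def equality_condition(config: Config) -> bool:
    # A stricter, hypothetical condition: changing one value always breaks it.
    return config['x1'] == config['x2']


bounds = {'x1': (1, 50), 'x2': (1, 50)}
some = one_exchange_neighbours({'x1': 9, 'x2': 11}, bounds, sample_condition)
none = one_exchange_neighbours({'x1': 5, 'x2': 5}, bounds, equality_condition)
print(some)  # [{'x1': 8, 'x2': 11}, {'x1': 9, 'x2': 10}]
print(none)  # []
```

Under the equality condition every single-parameter change is invalid, so the neighbourhood is empty even though the starting configuration is valid. Callers of a neighbourhood function should therefore handle an empty result.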

jhj0411jhj, Nov 15 '22 11:11