
Draft refactor plan of `nas.nn`

Open matluster opened this issue 3 years ago • 0 comments

Cleanup nas.nn

Target: NNI 3.0. Involves API-breaking changes.

Why this refactor?

  • The existing codebase is too messy. Mutators are scattered across the nn folder, which has made them very hard to maintain (not to mention implementing new ones).
  • For historical reasons, our IR is graph-based (a graph is mandatory). In practice, however, when running experiments with the Python / benchmark engine (which is actually the most common case), the graph is not needed at all.
  • Strategies want to program directly against a "space spec" in most cases, rather than "hack" the mutators and fool them with some FixedSampler.
  • We don't have a clear interface for writing new model spaces. All we have is a @model_wrapper, which is not Pythonic and whose behavior is not well defined. Models in the space hub have some ad-hoc classmethods (e.g., load_searched_model), but they are not formalized and there is a lot of duplicated code.

Philosophy

Let's review the responsibility of every component:

  • Mutation Primitives (high-level APIs)
    • Interface to help users build their model space.
    • Resides in nas.nn.
  • Mutator
    • Extends BaseMutator.
  • Engine
    • Resides in nas.execution.
  • Strategy
    • Resides in nas.strategy.

Upon initialization:

sequenceDiagram
    participant S as Strategy
    participant E as Engine
    participant M as Mutator
    participant I as Mutation Primitives
    S->>E: User-defined ModelSpace
    E->>M: User-defined ModelSpace
    M->>E: Base model + Mutators
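
The handshake in the diagram could be sketched in plain Python as follows. This is a self-contained toy (the `LayerChoice` stub, `extract_mutations`, and the list-based "model space" are all assumptions for illustration, not NNI's actual implementation): the engine walks the user-defined model space and splits it into a base model plus the mutators extracted from it.

```python
# Hypothetical stub of a mutation primitive; NNI's real LayerChoice is richer.
class LayerChoice:
    def __init__(self, candidates, label):
        self.candidates = candidates
        self.label = label

def extract_mutations(model_space):
    """Split a model space into (base model, mutators) as the diagram suggests."""
    mutators = [m for m in model_space if isinstance(m, LayerChoice)]
    # The base model keeps a default candidate in place of each choice.
    base = [m.candidates[0] if isinstance(m, LayerChoice) else m for m in model_space]
    return base, mutators

space = ['stem', LayerChoice(['conv3x3', 'conv5x5'], label='block1'), 'head']
base_model, mutators = extract_mutations(space)
print(base_model)                    # ['stem', 'conv3x3', 'head']
print([m.label for m in mutators])   # ['block1']
```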

Mutation Primitives Interface

class MutableMixin:
	init_capture: bool = False  # Detect ValueChoice in ``__init__`` arguments

	def __new__(cls, *args, **kwargs):
		instance = super().__new__(cls)
		# somewhere in the middle:
		return instance.call_hook('on_instantiate')
		# This is used to replace ``create_fixed_module``.
		# Engines implement this hook to customize the API's behavior.

	# Relationship between a space and one arch:
	def apply_arch(self: T, arch: Sample) -> T:
		"""Return a new instance of this type with a frozen architecture."""

	@classmethod
	def freeze(cls: Type[T], arch: Sample) -> Type[T]:
		"""Return a version of this class type that will create a fixed architecture."""

	def search_space_spec(self, recursive: bool = True) -> Schema: ...

	@staticmethod
	def next_label() -> str:
		"""Retrieve the next auto-label."""

	def to_graph(self) -> Graph:
		"""TBD"""

class MutableModule(MutableMixin, nn.Module):
	"""
	1. 
	2. Graph engine knows 
	"""
	def dry_run_forward(self, ...): ...

#class FixedCreationMixin:
#	@classmethod
#	def create_fixed_module(...): ...


class LayerChoice(MutableModule): ...

class InputChoice(MutableModule): ...

class ValueChoice(MutableModule): ...

...

class Conv2d(MutableModule, nn.Conv2d): ...

...

class Evaluator(MutableMixin): ...
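
The auto-labeling behind `next_label` could be sketched like this. The counter-per-prefix scheme and all names below are assumptions for illustration, not NNI's actual code:

```python
import itertools
from collections import defaultdict

# One counter per prefix, so every mutable of a given kind gets a
# deterministic auto-generated label (hypothetical scheme).
_counters = defaultdict(itertools.count)

def next_label(prefix: str) -> str:
    """Return the next auto-generated label for ``prefix``."""
    return f'{prefix}_{next(_counters[prefix])}'

print(next_label('layerchoice'))  # layerchoice_0
print(next_label('layerchoice'))  # layerchoice_1
print(next_label('valuechoice'))  # valuechoice_0
```

Deterministic labels matter here because a sampled architecture refers back to mutables by label; two runs over the same space must produce the same labels.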

Engine Implementation of APIs

# model_space.py
class BaseModelSpace(MutableMixin):
	status: ModelStatus

	def __init_subclass__(cls, ...):
		"""Intercept calls to subclass's ``__init__``."""

	def add_hyperparameter(self, values, type) -> ParameterSpec:
		"""Previously known as ModelHyperparameterChoice"""

	@classmethod
	def preset(cls: Type[T], space_alias: str, ...) -> T:
		"""For example: final_model = DARTS.freeze(some_arch).preset('darts-simple')"""

	@classmethod
	def searched_model(cls: Type[T], model_alias: str, ...) -> T:
		"""Fixed arch + space preset"""

class FullGraphModelSpace(BaseModelSpace):
	model_id: int
	graphs: dict[str, Graph]
	_root_graph_name: str = '_model'
	history: list[Mutation]

class SimplifiedModelSpace(BaseModelSpace):
	python_object: Any | None = None  # type is uncertain because it could differ between DL frameworks
	python_class: Type | None = None
	python_init_params: dict[str, Any] | None = None

class SimplifiedModel(Model):
	search_space_spec: Schema

	def __init__(self):
		self.model_id: int = uid('model')
		self.python_object: Optional[Any] = None  # type is uncertain because it could differ between DL frameworks
		self.python_class: Optional[Type] = None
		self.python_init_params: Optional[Dict[str, Any]] = None

		self.status: ModelStatus = ModelStatus.Mutating

		self._root_graph_name: str = '_model'
		self.graphs: Dict[str, Graph] = {}
		self.evaluator: Optional[Evaluator] = None

		self.history: List['Mutation'] = []

		self.metric: Optional[MetricData] = None
		self.intermediate_metrics: List[MetricData] = []
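
The `freeze` semantics described above (a classmethod that returns a version of the class producing a fixed architecture) can be sketched in a few lines. Everything here is a minimal toy under assumed names; the `DARTS` class and the arch dict are placeholders, not NNI's real API:

```python
class MutableMixin:
    arch = None  # the sample applied to this space, if any

    @classmethod
    def freeze(cls, arch):
        """Return a subclass whose instances are built with ``arch`` fixed."""
        return type(f'Frozen{cls.__name__}', (cls,), {'arch': dict(arch)})

class DARTS(MutableMixin):
    def __init__(self):
        # A frozen subclass sees its fixed arch; the raw space does not.
        self.chosen = self.arch or {'op': 'unsampled'}

FixedDARTS = DARTS.freeze({'op': 'sep_conv_3x3'})
model = FixedDARTS()
print(model.chosen['op'])  # sep_conv_3x3
```

Returning a subclass (rather than an instance, as `apply_arch` does) lets the frozen class be instantiated repeatedly, e.g. once per trial.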

Mutation on IRs

class SomeMutator(Mutator):
	def mutate(self, model):
		...

For most strategies, what they get from the model space is a schema dict; they only need to sample from that schema dict and apply the sample (as a fixed arch) onto the model space.

# distribution.py
class Distribution:
	...

class DiscreteDist(Distribution):
	prior: ...

class DistributionSchedule:
	def sample(self):
		...
		self.step()
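
One way `DistributionSchedule` could behave is sketched below: every `sample` draws from the current distribution and then advances the schedule, e.g. annealing a softmax temperature so the distribution sharpens over time. The annealing rule and all names are assumptions for illustration:

```python
import math
import random

class DistributionSchedule:
    """Toy schedule: softmax over logits with a temperature that decays per sample."""

    def __init__(self, logits, start_temp=5.0, decay=0.9, seed=0):
        self.logits = logits
        self.temperature = start_temp
        self.decay = decay
        self.rng = random.Random(seed)

    def step(self):
        self.temperature *= self.decay  # sharpen the distribution over time

    def sample(self):
        weights = [math.exp(l / self.temperature) for l in self.logits]
        # Inverse-CDF sampling over the unnormalized weights.
        r = self.rng.random() * sum(weights)
        acc = 0.0
        for idx, w in enumerate(weights):
            acc += w
            if r <= acc:
                choice = idx
                break
        self.step()  # advance the schedule as described in the draft
        return choice

sched = DistributionSchedule(logits=[0.1, 2.0, 0.5])
samples = [sched.sample() for _ in range(5)]
print(samples)
```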

# schema.py
class Schema:
	label: str
	def sample(self, memo: Sample) -> Sample: ...

class Discrete(Schema):
	choices: list[Choice]
	dist: Distribution | None
	def sample(self, memo): return random.choice(self.choices)

class Continuous(Schema):
	low: float
	high: float
	dist: Distribution | None

class SchemaList(Schema, list):
	schemas: list[Schema]
	dist: Distribution | None

class SchemaDict(Schema, dict):
	schemas: dict[str, Schema]
	dist: Distribution | None

class Condition(Schema):
	operator: Operator
	operand1: str
	operand2: str | None

	def sample(self, memo):
		if not self.operator(memo[self.operand1], memo[self.operand2]):
			raise SamplingError(...)

class Validation(Schema):
	validate: Callable[[dict], bool]
	def sample(self, memo):
		if not self.validate(memo):
			raise SamplingError(...)

# nas-specific
class Cell(Schema):
	...

class Mutator(Schema):
	sampler: Sampler
	def bind_sampler(self, sampler: Sampler) -> 'Mutator': ...
	def sample(self, memo): ...
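
A minimal, self-contained version of the memo-based sampling flow above (method bodies and the dict-of-schemas layout are assumptions; only the `Schema`/`Discrete` names and the `memo: Sample` idea come from the draft):

```python
import random

class Schema:
    def __init__(self, label):
        self.label = label

class Discrete(Schema):
    def __init__(self, label, choices):
        super().__init__(label)
        self.choices = choices

    def sample(self, memo, rng):
        # Schemas sharing a label reuse the decision recorded in ``memo``,
        # which is how tied choices stay consistent across the space.
        if self.label not in memo:
            memo[self.label] = rng.choice(self.choices)
        return memo[self.label]

rng = random.Random(0)
space = {
    'depth': Discrete('depth', [2, 3, 4]),
    'width': Discrete('width', [16, 32]),
}
memo = {}
arch = {name: schema.sample(memo, rng) for name, schema in space.items()}
print(arch)  # a fixed arch, e.g. {'depth': 3, 'width': 16}
```

A strategy would then hand `arch` back to the model space via `freeze` / `apply_arch`, never touching mutators directly.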

Tensorflow support

  • Framework agnostic:
    • MutableMixin
    • ModelSpace
  • Framework specific:
    • MutableModule (pytorch.MutableModule, tensorflow.MutableModule)
    • LayerChoice
    • ...

Legacy code

For reference.



class LayerChoiceMutator(Mutator):
    def __init__(self, nodes: List[Node]):
        super().__init__(label=nodes[0].operation.parameters['label'])
        self.nodes = nodes

    def mutate(self, model):
        candidates = self.nodes[0].operation.parameters['candidates']
        chosen = self.choice(candidates)
        for node in self.nodes:
            # Each layer choice corresponds to a cell, which is unconnected in the base graph.
            # We add the connections here in the mutation logic.
            # Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
            target = model.graphs[cast(Cell, node.operation).cell_name]
            chosen_node = target.get_node_by_name(chosen)
            assert chosen_node is not None
            target.add_edge((target.input_node, 0), (chosen_node, None))
            target.add_edge((chosen_node, None), (target.output_node, None))
            operation = cast(Cell, node.operation)
            target_node = cast(Node, model.get_node_by_name(node.name))
            target_node.update_operation(Cell(operation.cell_name))

            # remove redundant nodes
            for rm_node in list(target.hidden_nodes):  # remove from a list on the fly will cause issues
                if rm_node.name != chosen_node.name:
                    rm_node.remove()

matluster • May 09 '22 02:05