# Draft refactor plan of `nas.nn`

Cleanup of `nas.nn`. Target: NNI 3.0. Involves breaking API changes.
## Why this refactor?
- The existing codebase is too messy. Mutators are scattered across the `nn` folder, and it has become very hard to maintain them (not to mention implementing new ones).
- For historical reasons, our IR is graph-based (a graph is a must). In practice, however, when running experiments with the Python / benchmark engine (which is actually the most common case), no graph is needed at all.
- Strategies want to program directly against a "space spec" in most cases, rather than "hack" the mutators and fool them with some `FixedSampler`.
- We don't have a clear interface for writing new model spaces. All we have is `@model_wrapper`, which is not Pythonic and whose behavior is not well defined. Models in the space hub have some ad-hoc classmethods (e.g., `load_searched_model`), but they are not formalized and there is a lot of duplicated code.
## Philosophy
Let's review the responsibility of every component:
- Mutation Primitives (high-level APIs)
  - Interface to help users build their model space.
  - Resides in `nas.nn`.
- Mutator
  - Extends `BaseMutator`.
- Engine
  - Resides in `nas.execution`.
- Strategy
  - Resides in `nas.strategy`.
Upon initialization:
```mermaid
sequenceDiagram
    participant S as Strategy
    participant E as Engine
    participant M as Mutator
    participant I as Mutation Primitives
    S->>E: User-defined ModelSpace
    E->>M: User-defined ModelSpace
    M->>E: Base model + Mutators
```
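The same handshake can be sketched in plain Python. Every class below is a hypothetical stand-in for illustration only, not the real NNI API:

```python
# Hypothetical stand-ins for the participants in the diagram above.

class DummyMutator:
    """Represents one mutation point extracted from the user-defined space."""
    def __init__(self, label, candidates):
        self.label = label
        self.candidates = candidates

class DummyEngine:
    """Turns a user-defined model space into a base model plus mutators."""
    def process(self, model_space):
        mutators = [DummyMutator(label, cands)
                    for label, cands in model_space['choices'].items()]
        return model_space['base'], mutators

class DummyStrategy:
    """Hands the space to the engine and receives base model + mutators back."""
    def run(self, model_space, engine):
        return engine.process(model_space)

space = {'base': 'conv-net', 'choices': {'layer1': ['conv3x3', 'conv5x5']}}
base, mutators = DummyStrategy().run(space, DummyEngine())
assert base == 'conv-net'
assert [m.label for m in mutators] == ['layer1']
```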
## Mutation Primitives Interface
```python
class MutableMixin:
    init_capture: bool = False  # Detect ValueChoice in ``__init__`` arguments

    def __new__(cls, *args, **kwargs):
        # Somewhere in the middle:
        return self.call_hook('on_instantiate')
        # This hook replaces ``create_fixed_module``.
        # Engines implement it to customize the API's behavior.

    # Relationship between a space and one arch:
    def apply_arch(self: T, arch: Sample) -> T:
        """Return a new instance of this type with a frozen architecture."""

    @classmethod
    def freeze(cls: Type[T], arch: Sample) -> T:
        """Return a version of this class type that will create a fixed architecture."""

    def search_space_spec(self, recursive: bool = True) -> Schema: ...

    @staticmethod
    def next_label() -> str:
        """Retrieve the next auto-label."""

    def to_graph(self) -> Graph:
        """TBD"""
```
```python
class MutableModule(MutableMixin, nn.Module):
    """
    1.
    2. Graph engine knows
    """
    def dry_run_forward(self, ...): ...

# class FixedCreationMixin:
#     @classmethod
#     def create_fixed_module(...): ...

class LayerChoice(MutableModule): ...
class InputChoice(MutableModule): ...
class ValueChoice(MutableModule): ...
...

class Conv2d(MutableModule, nn.Conv2d): ...
...

class Evaluator(MutableMixin): ...
```
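To make the `search_space_spec` / `freeze` contract concrete, here is a toy stand-in. `ToyMutable` and all of its internals are illustrative assumptions, not part of the proposal:

```python
import random

# Toy stand-in for the MutableMixin contract: search_space_spec() exposes a
# schema (here just label -> choices), and freeze() returns a new instance
# bound to one fixed architecture. Names mirror the proposal; the
# implementation is illustrative only.

class ToyMutable:
    def __init__(self, label, choices):
        self.label = label
        self.choices = choices
        self.frozen = None  # set only on frozen copies

    def search_space_spec(self):
        return {self.label: self.choices}

    def freeze(self, arch):
        fixed = ToyMutable(self.label, self.choices)
        fixed.frozen = arch[self.label]
        return fixed

space = ToyMutable('kernel_size', [3, 5, 7])
random.seed(0)
# A strategy samples from the spec, then freezes the space with the sample.
sample = {label: random.choice(choices)
          for label, choices in space.search_space_spec().items()}
model = space.freeze(sample)
assert model.frozen in (3, 5, 7)
assert space.frozen is None  # the original space stays mutable
```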
## Engine Implementation of APIs
```python
# model_space.py
class BaseModelSpace(MutableMixin):
    status: ModelStatus

    def __init_subclass__(cls, ...):
        """Intercept calls to subclass's ``__init__``."""

    def add_hyperparameter(self, values, type) -> ParameterSpec:
        """Previously known as ModelHyperparameterChoice."""

    @classmethod
    def preset(cls: Type[T], space_alias: str, ...) -> T:
        """For example: final_model = DARTS.frozen(some_arch).preset('darts-simple')"""

    @classmethod
    def searched_model(cls: Type[T], model_alias: str, ...) -> T:
        """Fixed arch + space preset."""

class FullGraphModelSpace(BaseModelSpace):
    model_id: int
    graphs: dict[str, Graph]
    _root_graph_name: str = '_model'
    history: list[Mutation]

class SimplifiedModelSpace(BaseModelSpace):
    python_object: Any | None = None  # type is uncertain because it could differ between DL frameworks
    python_class: Type | None = None
    python_init_params: dict[str, Any] | None = None
```
```python
class SimplifiedModel(Model):
    search_space_spec: ...

    def __init__(self):
        self.model_id: int = uid('model')
        self.python_object: Optional[Any] = None  # type is uncertain because it could differ between DL frameworks
        self.python_class: Optional[Type] = None
        self.python_init_params: Optional[Dict[str, Any]] = None
        self.status: ModelStatus = ModelStatus.Mutating
        self._root_graph_name: str = '_model'
        self.graphs: Dict[str, Graph] = {}
        self.evaluator: Optional[Evaluator] = None
        self.history: List['Mutation'] = []
        self.metric: Optional[MetricData] = None
        self.intermediate_metrics: List[MetricData] = []
```
## Mutation on IRs
```python
class SomeMutator(Mutator):
    def mutate(self, model):
        ...
```
For most strategies, what they get from the model space is a schema dict; they only need to sample from that dict and apply the sample (as a fixed architecture) onto the model space.
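Under this design, a random-search strategy reduces to a few lines. The schema shape below (label mapped to a candidate list) is an assumption for illustration:

```python
import random

# Hypothetical schema dict, as a strategy would receive it from a model
# space: label -> list of candidate values. Shape and keys are made up.
schema = {
    'conv_type': ['conv3x3', 'conv5x5', 'sep-conv'],
    'width': [16, 32, 64],
}

def random_sample(schema, rng):
    """Draw one fixed architecture (a sample) from the schema dict."""
    return {label: rng.choice(choices) for label, choices in schema.items()}

rng = random.Random(42)
arch = random_sample(schema, rng)
# The strategy would then apply ``arch`` onto the model space to get a
# concrete model.
assert set(arch) == {'conv_type', 'width'}
```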
```python
# distribution.py
class Distribution:
    ...

class DiscreteDist:
    prior: ...

class DistributionSchedule:
    def sample(self):
        ...
        self.step()
```
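One way the schedule could work, annealing the prior on every draw, is sketched below. The prior handling and decay policy are assumptions; the proposal above only names the classes:

```python
import random

# Illustrative DiscreteDist / DistributionSchedule pair. The prior is a
# weight list over choices; step() shifts probability mass toward choice 0
# after each draw. All details here are assumptions.

class DiscreteDist:
    def __init__(self, choices, prior=None):
        self.choices = choices
        self.prior = prior or [1 / len(choices)] * len(choices)  # uniform default

    def sample(self, rng):
        return rng.choices(self.choices, weights=self.prior, k=1)[0]

class DistributionSchedule:
    """Anneals the prior a little on every sample."""
    def __init__(self, dist, decay=0.9):
        self.dist = dist
        self.decay = decay

    def sample(self, rng):
        value = self.dist.sample(rng)
        self.step()
        return value

    def step(self):
        # Decay all weights, then give the lost mass to choice 0,
        # keeping the weights normalized.
        shifted = [w * self.decay for w in self.dist.prior]
        shifted[0] += 1 - sum(shifted)
        self.dist.prior = shifted

rng = random.Random(0)
sched = DistributionSchedule(DiscreteDist(['a', 'b', 'c']))
samples = [sched.sample(rng) for _ in range(5)]
assert all(s in ('a', 'b', 'c') for s in samples)
```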
```python
# schema.py
class Schema:
    label: str

    def sample(self, memo: Sample) -> Sample: ...

class Discrete(Schema):
    choices: list[Choice]
    dist: Distribution | None

    def sample(self, memo):
        return random.choice(self.choices)

class Continuous(Schema):
    low: float
    high: float
    dist: Distribution | None

class SchemaList(Schema, list):
    schemas: list[Schema]
    dist: Distribution | None

class SchemaDict(Schema, dict):
    schemas: dict[str, Schema]
    dist: Distribution | None

class Condition(Schema):
    operator: Operator
    operand1: str
    operand2: str | None

    def sample(self, memo):
        if not self.operator(memo[self.operand1], memo[self.operand2]):
            raise SamplingError(...)

class Validation(Schema):
    validate: Callable[[dict], bool]

    def sample(self, memo):
        if not self.validate(memo):
            raise SamplingError(...)

# NAS-specific
class Cell(Schema):
    ...

class Mutator(Schema):
    sampler: Sampler

    def bind_sampler(self, sampler: Sampler) -> 'Mutator': ...
    def sample(self, memo): ...
```
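Since `Condition` and `Validation` signal violated constraints by raising `SamplingError`, a strategy would typically wrap sampling in a rejection loop. A self-contained sketch, where the constraint and helper names are made up for illustration:

```python
import random

class SamplingError(Exception):
    """Raised when a sampled configuration violates a constraint."""

def validate(sample):
    # Hypothetical Validation-style constraint:
    # kernel size must not exceed width.
    return sample['kernel'] <= sample['width']

def sample_with_rejection(rng, max_tries=100):
    """Resample until the constraint holds, as a strategy would."""
    for _ in range(max_tries):
        sample = {'kernel': rng.choice([3, 5, 7]),
                  'width': rng.choice([4, 8])}
        try:
            if not validate(sample):
                raise SamplingError(sample)
            return sample
        except SamplingError:
            continue  # reject and draw again
    raise RuntimeError('no valid sample found')

rng = random.Random(1)
s = sample_with_rejection(rng)
assert s['kernel'] <= s['width']
```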
## TensorFlow support
- Framework-agnostic:
  - `MutableMixin`
  - `ModelSpace`
- Framework-specific:
  - `MutableModule` (`pytorch.MutableModule`, `tensorflow.MutableModule`)
  - `LayerChoice`
  - ...
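This split works because the mixin carries no framework state; a framework-specific class simply combines it with the native module base, as `Conv2d(MutableModule, nn.Conv2d)` does earlier in this plan. A framework-free illustration of the same layering, with `FakeTorchModule` standing in for a real framework base:

```python
# Framework-free illustration of the mixin split. FakeTorchModule stands in
# for a framework's native module base (e.g. torch.nn.Module).

class MutableMixin:
    """Framework-agnostic: only knows about labels and search-space specs."""
    def search_space_spec(self):
        return getattr(self, '_spec', {})

class FakeTorchModule:
    """Stand-in for a framework's module base class."""
    def forward(self, x):
        return x

class MutableModule(MutableMixin, FakeTorchModule):
    """Framework-specific: ties the mixin to one framework's module type."""

class LayerChoice(MutableModule):
    def __init__(self, candidates, label):
        self._spec = {label: list(candidates)}

lc = LayerChoice(['conv3x3', 'conv5x5'], label='op1')
assert isinstance(lc, FakeTorchModule)  # usable as a native module
assert lc.search_space_spec() == {'op1': ['conv3x3', 'conv5x5']}
```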
## Legacy code
For reference.
```python
class LayerChoiceMutator(Mutator):
    def __init__(self, nodes: List[Node]):
        super().__init__(label=nodes[0].operation.parameters['label'])
        self.nodes = nodes

    def mutate(self, model):
        candidates = self.nodes[0].operation.parameters['candidates']
        chosen = self.choice(candidates)
        for node in self.nodes:
            # Each layer choice corresponds to a cell, which is unconnected in the base graph.
            # We add the connections here in the mutation logic.
            # Thus, the mutated model should not be mutated again. Everything should be based
            # on the original base graph.
            target = model.graphs[cast(Cell, node.operation).cell_name]
            chosen_node = target.get_node_by_name(chosen)
            assert chosen_node is not None
            target.add_edge((target.input_node, 0), (chosen_node, None))
            target.add_edge((chosen_node, None), (target.output_node, None))
            operation = cast(Cell, node.operation)
            target_node = cast(Node, model.get_node_by_name(node.name))
            target_node.update_operation(Cell(operation.cell_name))
            # Remove redundant nodes. Iterate over a copy, since removing from a list
            # on the fly would cause issues.
            for rm_node in list(target.hidden_nodes):
                if rm_node.name != chosen_node.name:
                    rm_node.remove()
```