
Questions about DARTS

buttercutter opened this issue 3 years ago · 52 comments

  1. For the DARTS complexity analysis, does anyone have an idea how to derive the (k+1)*k/2 expression? Why are there 2 input nodes? How would the calculated value change if graph isomorphism were considered? Why "2+3+4+5" learnable edges? If a connection can be absent, shouldn't the paper avoid adding 1 for it, since a missing connection does not actually contribute any learnable edge configuration?

  2. Why do the weights for normal cells and reduction cells need to be trained separately, as shown in Figures 4 and 5 below?

  3. How should the nodes be arranged so that the NAS search actually converges with minimum error? Note: not every node is connected to every other node.

  4. Why is GDAS 10 times faster than DARTS?

[Figure: DARTS normal and reduction cells]

[Figure: DARTS complexity analysis]

buttercutter avatar May 01 '21 04:05 buttercutter

Thanks for pointing out these questions.

(1). (k+1)k/2 is because, for the k-th node, you have (k+1) preceding nodes, and selecting two of them gives C(k+1, 2) possibilities. The 2 input nodes are pre-defined according to human experts' experience. If isomorphism is considered, you need another way to represent this DAG. Before pruning the fully-connected graph into the "2-input-nodes" version, each node has (k+1) preceding nodes and hence (k+1) incoming edges -> (1+1) + (2+1) + (3+1) + (4+1) = 14 learnable edges.

(2). We hypothesize that the normal cell and the reduction cell will have very different topologies.

(3). No theoretical guarantee.

(4). Because at each iteration, DARTS needs to compute the weighted sum, under the architecture parameters, of the outputs of every candidate operation -> O(N), whereas GDAS only needs to "sample" one candidate operation -> O(1).

D-X-Y avatar May 08 '21 03:05 D-X-Y
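
A minimal PyTorch sketch of point (4) above, showing the per-edge cost difference. This is not the repository's code; `ops`, `alphas`, and `x` are hypothetical names.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # A hypothetical edge with N candidate operations and their architecture logits.
    ops = nn.ModuleList([nn.Conv2d(16, 16, 3, padding=1), nn.MaxPool2d(3, stride=1, padding=1), nn.Identity()])
    alphas = nn.Parameter(torch.zeros(len(ops)))
    x = torch.randn(2, 16, 8, 8)

    # DARTS: run every candidate and take the softmax-weighted sum -> O(N) per edge.
    weights = F.softmax(alphas, dim=-1)
    darts_out = sum(w * op(x) for w, op in zip(weights, ops))

    # GDAS: sample one candidate via Gumbel-Max over the logits and run only that op -> O(1).
    gumbels = -torch.empty_like(alphas).exponential_().log()  # Gumbel(0, 1) noise
    index = int((alphas + gumbels).argmax())
    gdas_out = ops[index](x)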

Why does the k-th node have (k+1) preceding nodes?

buttercutter avatar May 09 '21 04:05 buttercutter

Because each cell also takes the outputs of the two previous cells as inputs. So for the first node in a cell, its preceding nodes are [last-cell-outputs, second-last-cell-outputs]. For the second node, they are [last-cell-outputs, second-last-cell-outputs, first-node-outputs].

D-X-Y avatar May 09 '21 06:05 D-X-Y
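
Writing that count out explicitly (this follows the explanation above; intermediate nodes are indexed k = 1, ..., 4, and the two cell inputs count as extra predecessors, so node k has (k+1) predecessors):

$\text{choices for node } k = \binom{k+1}{2} = \frac{(k+1)\,k}{2}$

$\text{learnable edges before pruning} = \sum_{k=1}^{4} (k+1) = 2 + 3 + 4 + 5 = 14$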

For the second node, what is the difference between last-cell-outputs and first-node-outputs?

Solution:

Each intermediate state, 0-3, is connected to each previous intermediate state as well as 
the output of the previous two cells, c_{k-2} and c_{k-1} (after a preprocessing layer).

buttercutter avatar May 09 '21 12:05 buttercutter

The last-cell-outputs is the output of green box c_{k-1}. The first-node-outputs is the output of blue box 0.

D-X-Y avatar May 09 '21 15:05 D-X-Y

If you add Gumbel-distributed noise to the logits and take the argmax, the Gumbel noise has exactly the right distribution so that the result is the same as softmaxing the logits and sampling from the discrete distribution defined by those probabilities.

Someone told me the above, but I am not familiar with the Gumbel distribution and how it actually helps speed up GDAS relative to DARTS. I suppose it is the Gumbel-Max trick mentioned in the paper. I do not quite understand expressions (3) and (5) in the GDAS paper.

buttercutter avatar May 10 '21 15:05 buttercutter
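
The statement quoted above is the Gumbel-Max trick. A minimal sketch (hypothetical variable names, not from the repository) of the two equivalent ways of drawing an operation index from the same logits:

    import torch

    logits = torch.tensor([1.0, 0.5, -0.3])

    # (a) Softmax the logits and sample from the resulting categorical distribution.
    probs = torch.softmax(logits, dim=-1)
    index_a = torch.multinomial(probs, num_samples=1)

    # (b) Gumbel-Max: add Gumbel(0, 1) noise to the logits and take the argmax.
    gumbels = -torch.empty_like(logits).exponential_().log()
    index_b = (logits + gumbels).argmax()

    # Over many draws, (a) and (b) produce indices with the same distribution,
    # but (b) is a plain argmax that GDAS can plug into its discrete forward pass.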

You could have a look at our code: https://github.com/D-X-Y/AutoDL-Projects/blob/main/lib/models/cell_searchs/search_model_gdas.py#L89

D-X-Y avatar May 10 '21 16:05 D-X-Y

@D-X-Y Could you comment on this reply about your Gumbel-Max code implementation?

@Unity05 was suggesting to use softargmax

buttercutter avatar May 16 '21 01:05 buttercutter

Hi, I was just explaining that the temperature you're using uses the same basic idea as softargmax.

Unity05 avatar May 16 '21 09:05 Unity05

@D-X-Y

In your code, would you be able to describe how the logic of hardwts = one_h - probs.detach() + probs is used in the forward search function feature = cell.forward_gdas(feature, hardwts, index)?

I mean, the computation logic for hardwts looks a bit strange to me. Why does hardwts need to make use of both one_h and probs? Why does one of the probs need detach()?

Besides, why does the Gumbel-Max computation need a while loop? I suppose you are using Gumbel(0, 1)?

How exactly does Gumbel-Max transform equation (3) into equation (5)?

buttercutter avatar May 16 '21 12:05 buttercutter

For the question on hardwts , see the note section inside https://pytorch.org/docs/stable/nn.functional.html#gumbel-softmax

The main trick for hard is to do y_hard - y_soft.detach() + y_soft

It achieves two things: 
- makes the output value exactly one-hot (since we add then subtract y_soft value) 
- makes the gradient equal to y_soft gradient (since we strip all other gradients)

@D-X-Y by the way, why does the PNASNet paper mention "Note that we learn a single cell type instead of distinguishing between Normal and Reduction cell."?

buttercutter avatar May 18 '21 17:05 buttercutter

Solution:

we do not distinguish between Normal and Reduction cells, 
but instead emulate a Reduction cell by using a Normal cell with stride 2

So, in this case, I suppose I could use only a single type of weights for both normal cells and reduction cells?

As for Algorithm 1, how is A different from W? Note: the explanation of the notation after equations (3) and (4) of the paper is very confusing to me.

buttercutter avatar May 21 '21 01:05 buttercutter

@D-X-Y

In your code, would you be able to describe how the logic of hardwts = one_h - probs.detach() + probs is used in the forward search function feature = cell.forward_gdas(feature, hardwts, index)?

I mean, the computation logic for hardwts looks a bit strange to me. Why does hardwts need to make use of both one_h and probs? Why does one of the probs need detach()?

Besides, why does the Gumbel-Max computation need a while loop? I suppose you are using Gumbel(0, 1)?

How exactly does Gumbel-Max transform equation (3) into equation (5)?

Sorry for the late reply, I'm a little bit busy these days.

hardwts = one_h - probs.detach() + probs aims to make hardwts have the same gradients as probs while still keeping the one-hot values -- one_h. The while loop is a trick I added myself to avoid very rare cases of NaN.

D-X-Y avatar May 21 '21 08:05 D-X-Y
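
A condensed sketch of the mechanism described above, modelled on the linked search_model_gdas.py but simplified (not the verbatim repository code):

    import torch
    import torch.nn.functional as F

    def gdas_sample(arch_logits, tau):
        # Straight-through Gumbel-Softmax sampling with a re-sample loop for numerical safety.
        while True:
            gumbels = -torch.empty_like(arch_logits).exponential_().log()  # Gumbel(0, 1)
            scaled = (arch_logits.log_softmax(dim=-1) + gumbels) / tau
            probs = F.softmax(scaled, dim=-1)
            index = probs.argmax(dim=-1, keepdim=True)
            one_h = torch.zeros_like(probs).scatter_(-1, index, 1.0)
            # Forward value is exactly one-hot; the gradient is that of `probs` only.
            hardwts = one_h - probs.detach() + probs
            # Very rare inf/nan values (e.g. from the division by a small tau) trigger a re-sample.
            if not (torch.isinf(gumbels).any() or torch.isinf(probs).any() or torch.isnan(probs).any()):
                return hardwts, index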

For the question on hardwts , see the note section inside https://pytorch.org/docs/stable/nn.functional.html#gumbel-softmax

The main trick for hard is to do y_hard - y_soft.detach() + y_soft

It achieves two things: 
- makes the output value exactly one-hot (since we add then subtract y_soft value) 
- makes the gradient equal to y_soft gradient (since we strip all other gradients)

@D-X-Y by the way, why does the PNASNet paper mention "Note that we learn a single cell type instead of distinguishing between Normal and Reduction cell."?

Yes, I borrowed the idea of how to implement the Gumbel trick from PyTorch, with a few modifications.

For PNAS, you may need to email their authors for the detailed reasons.

D-X-Y avatar May 21 '21 08:05 D-X-Y

Solution:

we do not distinguish between Normal and Reduction cells, 
but instead emulate a Reduction cell by using a Normal cell with stride 2

So, in this case, I suppose I could use only a single type of weights for both normal cells and reduction cells?

As for Algorithm 1, how is A different from W? Note: the explanation of the notation after equations (3) and (4) of the paper is very confusing to me.

Yes, in this case, the architecture weights for normal cells and reduction cells are shared. A is the architecture weights -- the logits assigned to each candidate operation. W is the weights of the supernet -- the weights of the convolution layers, etc.

D-X-Y avatar May 21 '21 08:05 D-X-Y

@D-X-Y I am a bit confused about the difference between a cell and a node.

Edit: I think I got it now. A single cell contains 4 distinct nodes.

By the way, in Algorithm 1, why does GDAS update W before A?

buttercutter avatar May 22 '21 09:05 buttercutter

@D-X-Y I am a bit confused about the difference between a cell and a node.

Edit: I think I got it now. A single cell contains 4 distinct nodes.

By the way, in Algorithm 1, why does GDAS update W before A?

I feel it does not matter? Updating W, A, W, A, W, A or A, W, A, W, A, W would not make a big difference?

D-X-Y avatar May 24 '21 03:05 D-X-Y

For GDAS, would https://networkx.org/documentation/stable/tutorial.html#multigraphs be suitable for both forward inference and backward propagation ?

buttercutter avatar May 26 '21 02:05 buttercutter

I'm not familiar with networkx and can not comment on that.

D-X-Y avatar May 26 '21 02:05 D-X-Y

@D-X-Y I am confused about how https://github.com/D-X-Y/AutoDL-Projects/blob/main/xautodl/models/cell_searchs/search_model_gdas.py implements multiple parallel connections between nodes.

buttercutter avatar May 26 '21 04:05 buttercutter

@D-X-Y I am confused about how equation (7) is an approximation of equation (5), as described in the GDAS paper.

buttercutter avatar Jun 01 '21 02:06 buttercutter

@promach The difference between $h$ in Eq.(5) and Eq.(7) is that:

  • In Eq.(5), it is a one-hot vector.
  • In Eq.(7), it is a soft-probability vector with Gumbel noise.

If you run Eq.(5) infinitely many times and run Eq.(7) infinitely many times, their average results should be very close.

D-X-Y avatar Jun 01 '21 03:06 D-X-Y
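
A quick empirical check of that claim, with toy logits and hypothetical names: the average of many Eq.(5)-style one-hot samples converges to softmax(logits), and the average of Eq.(7)-style soft vectors approaches the same values as the temperature tau is annealed towards 0.

    import torch
    import torch.nn.functional as F

    logits, tau, n = torch.tensor([2.0, 1.0, 0.1]), 0.1, 50_000

    hard_avg = torch.zeros(3)
    soft_avg = torch.zeros(3)
    for _ in range(n):
        g = -torch.empty_like(logits).exponential_().log()       # Gumbel(0, 1) noise
        hard_avg += F.one_hot((logits + g).argmax(), 3).float()  # Eq.(5)-style one-hot h
        soft_avg += F.softmax((logits + g) / tau, dim=-1)        # Eq.(7)-style soft h
    # With a small tau, all three printed vectors are close to each other.
    print(hard_avg / n, soft_avg / n, F.softmax(logits, dim=-1))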

@D-X-Y In normal backpropagation, there is only a single edge between two nodes.

However, in GDAS there are multiple parallel edges between two nodes.

So, how is backpropagation performed for GDAS or, more generally, for Neural Architecture Search (NAS)?

buttercutter avatar Jun 01 '21 08:06 buttercutter

For https://github.com/D-X-Y/AutoDL-Projects/issues/99#issuecomment-845789377, how do you actually update both W and A simultaneously within a single epoch?

Could you point me to the relevant code for the update portion?
Did you use two def forward() functions for W and A, since two disjoint parameter sets are used?

[Figure: GDAS Algorithm 1]

buttercutter avatar Jun 08 '21 12:06 buttercutter

@promach, at a single iteration, we will first update W and then update A. Please see the code here: https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NAS-Bench-201-algos/GDAS.py#L49

D-X-Y avatar Jun 09 '21 03:06 D-X-Y
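
A schematic of the alternation that the linked GDAS.py performs at every iteration, using hypothetical names (network, criterion, w_optimizer, a_optimizer, train_loader, valid_loader); the real loop also handles schedulers, gradient clipping, logging, etc.

    # One search epoch: alternate a supernet-weight step (W) and an architecture step (A).
    for (base_inputs, base_targets), (arch_inputs, arch_targets) in zip(train_loader, valid_loader):
        # Step 1: update W on a training batch; only W's parameters live in w_optimizer.
        w_optimizer.zero_grad()
        _, logits = network(base_inputs)
        criterion(logits, base_targets).backward()
        w_optimizer.step()

        # Step 2: update A on a validation batch; only A's parameters live in a_optimizer.
        a_optimizer.zero_grad()
        _, logits = network(arch_inputs)
        criterion(logits, arch_targets).backward()
        a_optimizer.step()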

If W is updated first and only then A, the question is: should I train the convolution kernel weights W based on the best trained edges for A?

buttercutter avatar Jun 10 '21 02:06 buttercutter

I feel the order of W and A does not matter. If you look at multiple iterations, it will be W -> A -> W -> A -> W -> A -> .... Whether the first one is W or A would not make a big difference.

D-X-Y avatar Jun 10 '21 03:06 D-X-Y

The issue lingering in my head is: if W is to be optimized FIRST, under which exact A result should W be trained?

buttercutter avatar Jun 10 '21 05:06 buttercutter

What do you mean by exact A?

D-X-Y avatar Jun 10 '21 05:06 D-X-Y

If W is trained FIRST, which A should the W training process use as W's architecture?

buttercutter avatar Jun 10 '21 05:06 buttercutter

A is a set of variables indicating the architecture encoding. There is only one A and no other options?

D-X-Y avatar Jun 10 '21 05:06 D-X-Y

You could have a look at the code here; would you mind clarifying what you think the code should be?

D-X-Y avatar Jun 10 '21 05:06 D-X-Y

Let me rephrase my question: how do you define base_inputs and arch_inputs?

It seems to be different from what the DARTS paper originally proposed. See equations (5) and (6) of the DARTS paper.

buttercutter avatar Jun 10 '21 06:06 buttercutter

base_inputs is a batch of samples from the training data, and arch_inputs is a batch of samples from the validation data.

Yes, following the DARTS paper, I should switch the order of updating W and A.

D-X-Y avatar Jun 10 '21 07:06 D-X-Y

During training of W, should I use a particular architecture found within that particular epoch, or should I use the whole supernet?

buttercutter avatar Jun 10 '21 08:06 buttercutter

It depends on the NAS algorithm. For DARTS, they use the whole supernet. For GDAS, we use an architecture candidate randomly sampled based on A.

D-X-Y avatar Jun 10 '21 09:06 D-X-Y

For GDAS, we use an architecture candidate randomly sampled based on A.

The candidate is chosen using Gumbel-argmax (equations (5) and (6) of the GDAS paper), rather than chosen randomly. Please correct me if I am wrong.

buttercutter avatar Jun 10 '21 09:06 buttercutter

Gumbel-argmax is a kind of random sampling, because the $o_{k}$ is randomly sampled from Gumbel(0, 1).

D-X-Y avatar Jun 10 '21 10:06 D-X-Y

For https://github.com/D-X-Y/AutoDL-Projects/issues/99#issuecomment-835802887, there are two types of outputs from the blue node.

One type of output (multiple edges) connects to the inputs of the other blue nodes?

Another type of output (a single edge) connects directly to the yellow node?

buttercutter avatar Jun 25 '21 12:06 buttercutter

It seems that both ENAS and PNAS just perform add and concat operations for the connection to the output node

buttercutter avatar Jun 27 '21 04:06 buttercutter

@D-X-Y I implemented draft code for GDAS.

However, could you advise whether this edge-weight training epoch mechanism will actually work for GDAS?

buttercutter avatar Jun 29 '21 04:06 buttercutter

For #99 (comment), there are two types of outputs from the blue node.

One type of output (multiple edges) connects to the inputs of the other blue nodes?

Another type of output (a single edge) connects directly to the yellow node?

Yes, you are right~

D-X-Y avatar Jun 29 '21 09:06 D-X-Y

It seems that both ENAS and PNAS just perform add and concat operations for the connection to the output node

@promach Yes. DARTS also uses add for the intermediate nodes and concat for the final output node (https://github.com/D-X-Y/AutoDL-Projects/blob/main/xautodl/models/cell_searchs/search_cells.py#L251).

D-X-Y avatar Jun 29 '21 09:06 D-X-Y
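
A minimal sketch of that wiring (edge_op is a hypothetical callable; in a DARTS-style cell it would be the mixed or sampled operation on each edge): every intermediate node sums ("add") its incoming edges, and the cell output concatenates the intermediate nodes along the channel dimension.

    import torch

    def cell_forward(s0, s1, edge_op, num_nodes=4):
        # s0, s1: outputs of the two previous cells; edge_op(i, j, h): operation on edge i -> j.
        states = [s0, s1]
        for j in range(num_nodes):
            # Intermediate node j: element-wise sum over all incoming edges from earlier states.
            states.append(sum(edge_op(i, j, h) for i, h in enumerate(states)))
        # Cell output: concatenate the intermediate nodes (not the two cell inputs) over channels.
        return torch.cat(states[2:], dim=1)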

@D-X-Y I implemented draft code for GDAS.

However, could you advise whether this edge-weight training epoch mechanism will actually work for GDAS?

I personally feel the implementation is incorrect. I haven't fully checked the code, but at the very least, the input to every cell/node should not be the same forward_edge(train_inputs).

D-X-Y avatar Jun 29 '21 09:06 D-X-Y

How do I code the forward pass function correctly for edge-weight training?

    # self-defined initial NAS architecture, for supernet architecture edge weight training
    def forward_edge(self, x):
        self.__freeze_f()
        self.__unfreeeze_w()

        return self.weights

Note: This is for training step 2 inside Algorithm 1 of DARTS paper

buttercutter avatar Jun 30 '21 11:06 buttercutter

Why do we return self.weights, instead of returning the result of applying the weights to x? The logic of freeze and unfreeze is correct, but I do not understand the return ...

D-X-Y avatar Jul 02 '21 09:07 D-X-Y
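
One way to act on D-X-Y's comment, sketched under the assumption that the edge module holds a list of candidate operations self.ops alongside the logits self.weights (the freeze/unfreeze names come from buttercutter's draft; self.ops is hypothetical, and this is not the repository's API). The forward pass should apply the softmaxed edge weights to the candidate outputs on x, rather than return the raw weights:

    import torch.nn.functional as F

    # self-defined initial NAS architecture, for supernet architecture edge weight training
    def forward_edge(self, x):
        self.__freeze_f()
        self.__unfreeeze_w()

        # DARTS-style mixed output: each candidate op's output, weighted by the softmaxed logits.
        probs = F.softmax(self.weights, dim=-1)
        return sum(p * op(x) for p, op in zip(probs, self.ops))

For GDAS specifically, the softmax weighting would be replaced by the hard Gumbel-Softmax weights (hardwts), so that only the sampled operation contributes to the forward pass.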

I am not sure how to train edge weights, hence the question about def forward_edge()

~~Besides, I also suspect the forward pass function for architecture weight (step 1 inside DARTS Algorithm 1) might be incorrect as well because it only trains the neural network function's internal weight parameters instead of architecture weight.~~

Note: self.f(x) is something like nn.Linear() or nn.Conv2d()

    # for NN functions internal weights training
    def forward_f(self, x):
        self.__unfreeze_f()
        self.__freeeze_w()

        # inheritance in python classes and SOLID principles
        # https://en.wikipedia.org/wiki/SOLID
        # https://blog.cleancoder.com/uncle-bob/2020/10/18/Solid-Relevance.html
        return self.f(x)

buttercutter avatar Jul 02 '21 09:07 buttercutter

Sorry, I misinterpreted the purpose of the two forward pass functions.

forward_edge() is for architecture weights (step 1), while forward_f() is for NN function's internal weights (step 2).

However, I am still not sure how to code for def forward_edge(self, x)

buttercutter avatar Jul 02 '21 09:07 buttercutter

@D-X-Y For an ordinary NN training operation, we have some feature-map outputs.

However, for the edge-weight (NAS) training operation, there are no feature-map outputs. So, what should be fed into x for forward_edge(x)?

buttercutter avatar Jul 09 '21 06:07 buttercutter

Is using nn.Linear() to train edge weights feasible for GDAS on a small GPU ?

    # self-defined initial NAS architecture, for supernet architecture edge weight training
    def forward_edge(self, x):
        self.__freeze_f()
        self.__unfreeeze_w()

        return self.linear(x)

buttercutter avatar Jul 12 '21 11:07 buttercutter

@D-X-Y I managed to get my own GDAS code implementation up and running.

However, the loss stays the same, which indicates that the training process is still incorrect.

Could you advise ?

buttercutter avatar Jul 26 '21 15:07 buttercutter

@D-X-Y Looking at the output of graph.named_parameters(), some of the internal connections within the supernet architecture still do not seem to be wired properly. Any comments?

buttercutter avatar Nov 12 '21 02:11 buttercutter