stochastic_depth_keras

How does the gate update during training?

DehuaTang opened this issue 7 years ago • 4 comments

Thanks a lot for doing this. I might be misunderstanding something, but I can't see how the gate is updated during training. I can't understand these two pieces of code:

gate = K.variable(1, dtype="uint8")
add_tables += [{"death_rate": _death_rate, "gate": gate}]
return Lambda(lambda tensors: K.switch(gate, tensors[0], tensors[1]),
              output_shape=output_shape)([out, x])

Is this 'gate' always equal to 1 when training?

class GatesUpdate(Callback):
    def on_batch_begin(self, batch, logs={}):
        open_all_gates()

        rands = np.random.uniform(size=len(add_tables))
        for t, rand in zip(add_tables, rands):
            if rand < K.get_value(t["death_rate"]):
                K.set_value(t["gate"], 0)

Does this 'GatesUpdate' act on the 'Lambda' layer during training? Thank you.

DehuaTang avatar Oct 03 '17 02:10 DehuaTang

@DehuaTang

Sorry for the messy code.

Is this 'gate' always equal to 1 when training?

No. Its value is changed in GatesUpdate by K.set_value(t["gate"], 0). The behavior of this Lambda layer is decided at runtime (during sess.run(...)), depending on this gate variable.

Does this 'GatesUpdate' act on the 'Lambda' layer during training? Thank you.

It does not act on the Lambda layer itself, but on the variable (gate = K.variable(1, dtype="uint8")) that the layer uses.
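To see this idea outside of Keras, here is a minimal pure-Python/NumPy sketch (the names `gate` and `block` are illustrative, not from the repo): the function closes over a mutable gate, so flipping the gate changes which branch is returned at call time, just as K.set_value changes what K.switch selects during sess.run.

```python
import numpy as np

# Mutable stand-in for K.variable(1, dtype="uint8")
gate = {"value": 1}

def block(x):
    out = 2.0 * x  # stands in for the residual branch's transformation
    # analogue of K.switch(gate, out, x): transformed branch when the
    # gate is open (1), identity shortcut when it is closed (0)
    return out if gate["value"] else x

x = np.ones(3)
y_open = block(x)    # gate open  -> transformed branch
gate["value"] = 0    # analogue of K.set_value(t["gate"], 0)
y_closed = block(x)  # gate closed -> identity shortcut
```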

dblN avatar Oct 03 '17 04:10 dblN

Thank you for your reply. You are right, and I understand it now. I want to add some code to check how the 'gate' variable and the 'Lambda' layer change at runtime (sess.run(...)).

return Lambda(lambda tensors: K.switch(gate, tensors[0], tensors[1]),
              output_shape=output_shape)([out, x])

Can you give me some advice? Thank you.

DehuaTang avatar Oct 04 '17 07:10 DehuaTang

@DehuaTang Sorry for the late reply. Do you mean you want to inspect the state of the gates? Assuming you have programming skills, this example should be enough.

...
class GatesUpdate(Callback):
    def on_batch_begin(self, batch, logs={}):
        print("[ Batch %d ]" % batch)  # added
        open_all_gates()

        rands = np.random.uniform(size=len(add_tables))
        for i, (t, rand) in enumerate(zip(add_tables, rands)):
            if rand < K.get_value(t["death_rate"]):
                print("%d-th gate is closed" % i)  # added
                K.set_value(t["gate"], 0)
            else:
                print("%d-th gate is open" % i)  # added

    def on_batch_end(self, batch, logs={}):
        open_all_gates()  # for validation
...
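As a side note (not from the thread), the same per-batch logic can be traced without a Keras session by swapping the backend variables for plain Python values; `add_tables` below is a hypothetical stand-in for the script's global of the same name.

```python
import numpy as np

# Hypothetical stand-ins: plain dicts instead of K.variable / K.set_value,
# so the gate states can be inspected without a TF session.
add_tables = [{"death_rate": 0.5, "gate": 1} for _ in range(4)]
rng = np.random.default_rng(0)

def open_all_gates():
    for t in add_tables:
        t["gate"] = 1

def on_batch_begin(batch):
    # reopen every gate, then randomly close some for this batch
    open_all_gates()
    rands = rng.uniform(size=len(add_tables))
    for t, rand in zip(add_tables, rands):
        if rand < t["death_rate"]:
            t["gate"] = 0
    return [t["gate"] for t in add_tables]

states = [on_batch_begin(b) for b in range(3)]
```

In the real script the callback would presumably be registered via model.fit's callbacks argument, e.g. `model.fit(..., callbacks=[GatesUpdate()])`.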

I'm sorry but I do not have time to answer more of your questions. Thank you for your feedback.

dblN avatar Oct 05 '17 22:10 dblN

Thank you for helping me solve the problem!! Sorry for the late "star". ★

DehuaTang avatar Oct 06 '17 02:10 DehuaTang