stochastic_depth_keras
How is the gate updated during training?
Thanks a lot for doing this. I might be misunderstanding something, but I can't understand how the gate is updated during training. I can't understand these two pieces of code:
```python
gate = K.variable(1, dtype="uint8")
add_tables += [{"death_rate": _death_rate, "gate": gate}]
return Lambda(lambda tensors: K.switch(gate, tensors[0], tensors[1]),
              output_shape=output_shape)([out, x])
```
Is this `gate` always equal to 1 during training?
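```python
import numpy as np
from keras import backend as K
from keras.callbacks import Callback

class GatesUpdate(Callback):
    def on_batch_begin(self, batch, logs=None):
        open_all_gates()  # start every batch with all gates open
        rands = np.random.uniform(size=len(add_tables))
        for t, rand in zip(add_tables, rands):
            # close this block's gate with probability death_rate
            if rand < K.get_value(t["death_rate"]):
                K.set_value(t["gate"], 0)
```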
Does `GatesUpdate` act on the `Lambda` layer during training? Thank you.
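To see what `GatesUpdate` does to the gates in isolation, the sketch below replays its sampling loop with plain Python dicts instead of Keras variables. The five death rates are made-up values, and `open_all_gates` here is a hypothetical stand-in assumed to set every gate back to 1 (the repo's own helper is not shown in this thread). Over many simulated batches, each gate should end up closed in roughly a `death_rate` fraction of them, i.e. each gated block is dropped independently with its own probability.

```python
import numpy as np

# Plain-Python stand-ins for the Keras variables: each table holds a
# death rate and a gate (1 = open, 0 = closed). Hypothetical values.
add_tables = [{"death_rate": 0.1 * (i + 1), "gate": 1} for i in range(5)]


def open_all_gates():
    """Hypothetical stand-in for the repo helper: reset every gate to 1."""
    for t in add_tables:
        t["gate"] = 1


def sample_gates(rng):
    """One on_batch_begin step: close each gate with probability death_rate."""
    open_all_gates()
    rands = rng.uniform(size=len(add_tables))
    for t, rand in zip(add_tables, rands):
        if rand < t["death_rate"]:
            t["gate"] = 0
    return [t["gate"] for t in add_tables]


rng = np.random.default_rng(0)
# One row per simulated batch, one column per gate.
history = np.array([sample_gates(rng) for _ in range(10000)])
closed_fraction = 1.0 - history.mean(axis=0)
print(np.round(closed_fraction, 2))  # each entry should be near its death_rate
```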
@DehuaTang
Sorry for the messy code.
> Is this `gate` always equal to 1 during training?
No. Its value is changed in `GatesUpdate` by `K.set_value(t["gate"], 0)`. The behavior of this `Lambda` layer is decided at runtime (`sess.run(~)`), depending on this gate variable.
> Does `GatesUpdate` act on the `Lambda` layer during training? Thank you.
Not on the `Lambda` layer, but on the variable it uses (`gate = K.variable(1, dtype="uint8")`).
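The point that the callback changes the variable, not the layer, can be illustrated without Keras. The sketch below is a pure-Python analogue, not the Keras API: `Variable`, `get_value`, `set_value`, `switch` and `make_gated_block` are hypothetical stand-ins for `K.variable`, `K.get_value`, `K.set_value`, `K.switch` and the `Lambda` layer. The closure holds a reference to the gate object, so each call reads the gate's current value, much as the graph built around `K.switch` reads the variable each time `sess.run` evaluates it.

```python
import numpy as np


class Variable:
    """Minimal stand-in for a backend variable: a mutable value holder."""

    def __init__(self, value):
        self.value = value


def get_value(v):
    return v.value


def set_value(v, value):
    v.value = value


def switch(condition, then_value, else_value):
    """Scalar-condition select, analogous to K.switch with a scalar gate."""
    return then_value if condition else else_value


def make_gated_block(gate):
    # The returned function captures the gate object, not its current value,
    # so the value is looked up every time the block is executed.
    return lambda tensors: switch(get_value(gate), tensors[0], tensors[1])


gate = Variable(1)
block = make_gated_block(gate)
out = np.array([2.0, 3.0])  # stands in for the block's output tensor
x = np.array([1.0, 1.0])    # stands in for the shortcut input

print(block([out, x]))  # gate open: returns out
set_value(gate, 0)      # what GatesUpdate does via K.set_value
print(block([out, x]))  # gate closed: returns x; the block itself is unchanged
```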
Thank you for your reply. You are right, and I understand it now. I want to add some code to check how the variable `gate` and the `Lambda` layer change at runtime (`sess.run(~)`):
```python
return Lambda(lambda tensors: K.switch(gate, tensors[0], tensors[1]),
              output_shape=output_shape)([out, x])
```
Can you give me some advice? Thank you.
@DehuaTang Sorry for the late reply. Do you mean you want to know the state of the gates? Assuming you have some programming experience, this example should be enough:
```python
...
class GatesUpdate(Callback):
    def on_batch_begin(self, batch, logs=None):
        print("[ Batch %d ]" % batch)  # added
        open_all_gates()
        rands = np.random.uniform(size=len(add_tables))
        for i, (t, rand) in enumerate(zip(add_tables, rands)):
            if rand < K.get_value(t["death_rate"]):
                print("%d-th gate is closed" % i)  # added
                K.set_value(t["gate"], 0)
            else:
                print("%d-th gate is open" % i)  # added

    def on_batch_end(self, batch, logs=None):
        open_all_gates()  # reopen all gates for validation
...
```
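If the goal is to inspect gate states after training rather than read them from the console, one option is to store them per batch. The sketch below is hypothetical: `GateRecorder` is not part of the repo, it only mirrors the `on_batch_begin`/`on_batch_end` hook names of `keras.callbacks.Callback`, and it uses plain ints for gates so it runs without Keras. It combines the sampling from `GatesUpdate` with a history list; in a real model it would subclass `Callback` and read each gate with `K.get_value(t["gate"])`. The death rates 0.0 and 1.0 are chosen so the first gate is always open and the last always closed.

```python
import numpy as np


class GateRecorder:
    """Hypothetical callback-like class that records gate states per batch."""

    def __init__(self, add_tables, seed=0):
        self.add_tables = add_tables
        self.rng = np.random.default_rng(seed)
        self.history = []  # one list of gate values (0/1) per batch

    def on_batch_begin(self, batch, logs=None):
        for t in self.add_tables:
            t["gate"] = 1  # open all gates
        rands = self.rng.uniform(size=len(self.add_tables))
        for t, rand in zip(self.add_tables, rands):
            if rand < t["death_rate"]:
                t["gate"] = 0  # close with probability death_rate
        self.history.append([t["gate"] for t in self.add_tables])

    def on_batch_end(self, batch, logs=None):
        for t in self.add_tables:
            t["gate"] = 1  # reopen all gates, as the callback above does


tables = [{"death_rate": r, "gate": 1} for r in (0.0, 0.5, 1.0)]
recorder = GateRecorder(tables)
for batch in range(4):
    recorder.on_batch_begin(batch)
    recorder.on_batch_end(batch)

for batch, states in enumerate(recorder.history):
    print("[ Batch %d ] gates: %s" % (batch, states))
```

Because `on_batch_end` reopens every gate, the recorded history is the only place the per-batch pattern survives; reading the gate variables after training would show them all open.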
I'm sorry, but I don't have time to answer further questions. Thank you for your feedback.
Thank you for helping me solve the problem!! Sorry for the late star. ★