Separating evolution from representation
Hello! First, thank you for this amazing post. I tried to modify the model to separate the evolution from the representation, meaning that I have one function that evolves the state and, at the end, another function that takes the evolved state and computes a representation, which is then compared to the target image. (I also moved the living channel to the first channel, and that worked fine.) However, it seems that the gradient is not propagated to the weights of the evolution layer. Do you know why?
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# CHANNEL_N, CELL_FIRE_RATE and get_living_mask are taken from the original
# notebook (get_living_mask adapted for the living channel at index 0).

class CAModel(tf.keras.Model):

  def __init__(self, channel_n=CHANNEL_N, fire_rate=CELL_FIRE_RATE):
    super().__init__()
    self.channel_n = channel_n
    self.fire_rate = fire_rate

    # The evolution branch sees the full perception vector (state + Sobel
    # gradients); the representation branch only sees the current state.
    input_with_gradient = tf.keras.Input(shape=(None, None, self.channel_n * 3),
                                         name="gradient")
    current_state = tf.keras.Input(shape=(None, None, self.channel_n),
                                   name="current")
    evolution = layers.Conv2D(self.channel_n, 1, activation=tf.nn.relu,
                              name="evolution")(input_with_gradient)
    representation = layers.Conv2D(3, 1, activation=tf.nn.relu,
                                   name="representation")(current_state)
    self.model = tf.keras.Model(inputs=[current_state, input_with_gradient],
                                outputs=[evolution, representation],
                                name="global")

    self(tf.zeros([1, 3, 3, channel_n]))  # dummy call to build the model
  @tf.function
  def perceive(self, x, angle=0.0):
    identify = np.float32([0, 1, 0])
    identify = np.outer(identify, identify)
    dx = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0  # Sobel filter
    dy = dx.T
    c, s = tf.cos(angle), tf.sin(angle)
    kernel = tf.stack([identify, c*dx-s*dy, s*dx+c*dy], -1)[:, :, None, :]
    kernel = tf.repeat(kernel, self.channel_n, 2)
    y = tf.nn.depthwise_conv2d(x, kernel, [1, 1, 1, 1], 'SAME')
    return y
  @tf.function
  def call(self, x, fire_rate=None, angle=0.0, step_size=1.0):
    pre_life_mask = get_living_mask(x)

    y = self.perceive(x, angle)
    dx, representation = self.model([x, y])
    dx = dx * step_size
    if fire_rate is None:
      fire_rate = self.fire_rate
    update_mask = tf.random.uniform(tf.shape(x[:, :, :, :1])) <= fire_rate
    x += dx * tf.cast(update_mask, tf.float32)

    post_life_mask = get_living_mask(x)
    life_mask = pre_life_mask & post_life_mask
    casted_life_mask = tf.cast(life_mask, tf.float32)
    # The returned representation is the alpha channel (the life mask)
    # followed by the three RGB channels.
    return (x * casted_life_mask,
            tf.concat([casted_life_mask, representation * casted_life_mask],
                      axis=-1))
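For reference, here is a minimal forward-pass check of the two outputs. The 64x64 grid size and the single-cell seed (with the living channel at index 0, as described above) are assumptions for illustration, not values from the post:

ca = CAModel()
# Hypothetical seed: one living cell in the middle of the grid, with the
# living channel moved to index 0. The grid size is an arbitrary choice.
x0 = np.zeros([1, 64, 64, CHANNEL_N], np.float32)
x0[:, 32, 32, 0] = 1.0
x, representation = ca(tf.convert_to_tensor(x0))
print(x.shape)               # (1, 64, 64, CHANNEL_N)
print(representation.shape)  # (1, 64, 64, 4): alpha + RGB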
And for the evolution of the state:
for i in tf.range(iter_n):
  x, representation = ca(x)
loss = tf.reduce_mean(loss_f(representation, img))
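A quick way to check whether the gradient actually reaches the evolution layer is to inspect the per-variable gradients directly. A minimal sketch, assuming the `img`, `loss_f` and `iter_n` from the notebook and the seed `x0` from the check above:

with tf.GradientTape() as tape:
  x = tf.convert_to_tensor(x0)
  for _ in range(iter_n):
    x, representation = ca(x)
  loss = tf.reduce_mean(loss_f(representation, img))

# A None gradient (or a zero norm) for the "evolution" kernel would confirm
# that the loss is disconnected from the evolution weights.
grads = tape.gradient(loss, ca.weights)
for w, g in zip(ca.weights, grads):
  print(w.name, None if g is None else tf.norm(g).numpy())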
So I changed the representation layer to:

representation = layers.Conv2D(3, 1, activation=None,
                               kernel_initializer=tf.zeros_initializer,
                               name="representation")(current_state)
I then checked that the gradient of the evolution layer is no longer null. However, the loss decreases at the beginning and then plateaus at a high value, and the resulting image does not look at all like the target.