Batchnorm handles empty batches incorrectly
Context: I have sparse example data with sub-structure (think of `message data { repeated float sub_0 = 1; repeated float sub_1 = 2; ... }`), where each sub-field is processed by a different branch of the network. The number of items in a sub-field varies, so reduction operations are used to obtain a fixed-size vector of shape [batch_size, embedding_dim]. In particular, the number of items can be 0 for one sub-field, in which case the output shape is [0, embedding_dim].
While all other operations correctly handle a size of 0 on the batch axis, the update of the moving mean/variance causes issues, because tf.nn.moments (which is used in build_batch_stats) returns NaNs for empty inputs.
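
For reference, a minimal sketch of that underlying behaviour, using the same TF1-style API as the repro below:

import tensorflow as tf

# A [0, 3] tensor stands in for an empty batch of 3-dim embeddings.
empty = tf.zeros([0, 3])
mean, variance = tf.nn.moments(empty, axes=[0])
with tf.Session() as sess:
    # Both evaluate to [nan nan nan]: the mean divides by a count of 0.
    print(sess.run([mean, variance]))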
There should be a guard around this, so that no updates with NaN values are applied when the inputs are 0-sized.
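
A minimal sketch of the kind of guard I have in mind (the helper name guarded_update and its signature are mine, not Sonnet's): wrap the moving-average assignment in a tf.cond so it becomes a no-op whenever the batch axis is empty.

import tensorflow as tf

def guarded_update(moving_stat, batch_stat, decay, batch_size):
    """Exponential moving-average update that is skipped for empty batches.

    Hypothetical helper; batch_stat may be NaN when batch_size == 0, but
    the assign only runs on the non-empty branch of the cond.
    """
    def do_update():
        # Standard EMA step: moving <- moving - (moving - batch) * (1 - decay).
        return tf.assign_sub(moving_stat,
                             (moving_stat - batch_stat) * (1.0 - decay))

    def skip_update():
        # Leave the moving statistic untouched for empty batches.
        return tf.identity(moving_stat)

    return tf.cond(batch_size > 0, do_update, skip_update)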
Example to reproduce:
import sonnet as snt
import tensorflow as tf

tf.reset_default_graph()

x = tf.placeholder(dtype=tf.float32, shape=[None, None, None])
y = tf.reduce_sum(x, axis=1, keepdims=False)
y = tf.reshape(y, [-1, 3])
y_bn = snt.BatchNormV2(scale=True,
                       update_ops_collection=tf.GraphKeys.UPDATE_OPS)(y, True)

a = [
    [[1, 2, 3],
     [1, 3, 5]],
    [[4, 5, 3],
     [4, 5, 3]],
    [[8, 9, 0],
     [1, 0, 0]],
]
b = [[[]], [[]]]  # reduces to shape [2, 0], then reshapes to [0, 3]
inputs = (a, b, a)

sess = tf.Session()
sess.run(tf.global_variables_initializer())  # initialize_all_variables is deprecated
for inp in inputs:
    output = sess.run(fetches=y_bn, feed_dict={x: inp})
    sess.run(tf.get_collection(tf.GraphKeys.UPDATE_OPS), feed_dict={x: inp})
    print("y_bn:", output)
    output = sess.run(
        {var.name: var.value()
         for var in tf.get_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES)},
        feed_dict={x: inp})
    for name, var in output.items():
        print(name, ":", var)