tensorflow-wavenet
How is the causal conv implementation causal?
Take this simple example:
```python
import tensorflow as tf
import numpy as np
import wavenet

n_batches = 1
n_samples = 8
quantisation = 1

x = tf.placeholder(tf.float32, (n_batches, n_samples, quantisation))
x_sample = np.arange(0, n_batches * n_samples * quantisation).reshape(
    (n_batches, n_samples, quantisation))

dilation = 2
filter_size = 2
f = tf.placeholder(tf.float32, (filter_size, 1, 1))
f_sample = np.zeros(filter_size).reshape(filter_size, 1, 1)
f_sample[filter_size - 1] = 1

with tf.Session() as session:
    result = session.run(
        wavenet.causal_conv(x, f, dilation),
        feed_dict={
            x: x_sample,
            f: f_sample
        })

print(x_sample)
print(result)
```
We get,
```
[[[0]
  [1]
  [2]
  [3]
  [4]
  [5]
  [6]
  [7]]]
[[[2.]
  [3.]
  [4.]
  [5.]
  [6.]
  [7.]]]
```
That is, we just chop off the first (filter_size - 1) * dilation = 2 elements. If you increase the dilation or the filter size, you just cut off more.
To interpret this as causal, are we just saying that the output sequence element t' is actually at time t=t'+2?
If so, how is that convention enforced elsewhere in the model?
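To make the convention concrete, here is how I am reading causal_conv in plain numpy (my own sketch, not code from this repo): output element t' only ever reads inputs at indices t' through t' + (filter_size - 1) * dilation, so relabelling the output index as t = t' + (filter_size - 1) * dilation makes it causal.

```python
import numpy as np

# Plain-numpy sketch of what I understand causal_conv to compute
# (my own re-implementation, not code from this repo).
def causal_conv_ref(x, f, dilation):
    filter_width = len(f)
    shift = (filter_width - 1) * dilation      # 2 in the example above
    out = np.zeros(len(x) - shift)
    for t_prime in range(len(out)):
        # Output element t' only reads x[t'], x[t' + dilation], ..., x[t' + shift],
        # so relabelling t = t' + shift means the op never looks past time t.
        out[t_prime] = sum(f[k] * x[t_prime + k * dilation]
                           for k in range(filter_width))
    return out

x = np.arange(8, dtype=float)
f = np.array([0.0, 1.0])                       # same "pick the last tap" filter
print(causal_conv_ref(x, f, dilation=2))       # [2. 3. 4. 5. 6. 7.]
```

If that reading is right, the cropping by itself is not what makes the op causal; it just fixes which relabelling of the output indices is consistent with causality, and the rest of the model has to respect that offset.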
I see that, to calculate the loss, we cut off a chunk of length equal to the receptive field from the beginning of the input batch of encodings (sketched below).
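If I am reading loss() in model.py correctly, the slice I mean looks roughly like this (paraphrased and simplified, so the exact variable names may be off):

```python
# Paraphrase of the target slicing in WaveNetModel.loss() (simplified).
# The network's first usable prediction corresponds to input index
# receptive_field, so the targets are shifted forward by the same amount.
target_output = tf.slice(
    tf.reshape(encoded, [batch_size, -1, quantization_channels]),
    [0, receptive_field, 0],
    [-1, -1, -1])
```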
I see also that in _create_dilation_layer, we cut off some amount at the beginning of the output of the dilation layer to make it compatible with the output_width param, which is in turn set to the network input length minus the receptive field length, plus one.
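And the trimming in _create_dilation_layer that I am referring to looks roughly like this (again paraphrased and simplified, so names may differ):

```python
# Paraphrase of the trimming in _create_dilation_layer() (simplified).
# The skip path is cut at the front so every layer's contribution lines up
# with the last output_width time steps of the network.
skip_cut = tf.shape(out)[1] - output_width
out_skip = tf.slice(out, [0, skip_cut, 0], [-1, -1, -1])
skip_contribution = tf.nn.conv1d(out_skip, weights_skip,
                                 stride=1, padding="SAME", name="skip")

# The residual path similarly trims the layer input at the front so it can
# be added to the (shorter) convolution output.
input_cut = tf.shape(input_batch)[1] - tf.shape(transformed)[1]
input_batch_cut = tf.slice(input_batch, [0, input_cut, 0], [-1, -1, -1])
residual = input_batch_cut + transformed
```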
Are these two bits all that is necessary to make this implementation of WaveNet causal?