tf.data pipeline currently does not feed `batch_dim`
For the tf.data pipeline (#292), using:
def dataset_pipeline(context):
    from returnn.tf.compat import v1 as tf
    dataset = context.get_returnn_dataset()
    dataset = dataset.padded_batch(
        batch_size=12,
        padded_shapes=tf.data.get_output_shapes(dataset),
        drop_remainder=False,
    )
    dataset = context.map_producer_to_consumer(dataset)
    dataset = context.prefetch_to_consumer_device(dataset)
    return dataset
I get the following exception:
EXCEPTION
Traceback (most recent call last):
File "/work/asr4/rossenbach/sisyphus_work_folders/tts_asr_2021_work/i6_core/tools/git/CloneGitRepositoryJob.3OlznTemDM9T/output/repository/returnn/tf/engine.py", line 689, in Runner.run
line: fetches_results = sess.run(
fetches_dict, feed_dict=feed_dict) # type: typing.Dict[str,typing.Union[numpy.ndarray,str]]
locals:
fetches_results = <not found>
sess = <local> <tensorflow.python.client.session.Session object at 0x14cae45105e0>
sess.run = <local> <bound method BaseSession.run of <tensorflow.python.client.session.Session object at 0x14cae45105e0>>
fetches_dict = <local> {'size:audio_features:0': <tf.Tensor 'IteratorGetNext:2' shape=(?,) dtype=int32>, 'size:bpe_labels:0': <tf.Tensor 'IteratorGetNext:3' shape=(?,) dtype=int32>, 'loss': <tf.Tensor 'objective/loss/add:0' shape=() dtype=float32>, 'cost:ctc': <tf.Tensor 'objective/loss/loss/loss_ctc/Sum:0' shape=() dt..., len = 15
File "/work/tools/asr/python/3.8.0_tf_2.3.4-haswell+cuda10.1+mkl/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 957, in BaseSession.run
line: result = self._run(None, fetches, feed_dict, options_ptr,
run_metadata_ptr)
locals:
result = <not found>
self = <local> <tensorflow.python.client.session.Session object at 0x14cae45105e0>
self._run = <local> <bound method BaseSession._run of <tensorflow.python.client.session.Session object at 0x14cae45105e0>>
run_metadata_ptr = <local> None
File "/work/tools/asr/python/3.8.0_tf_2.3.4-haswell+cuda10.1+mkl/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1180, in BaseSession._run
line: results = self._do_run(handle, final_targets, final_fetches,
feed_dict_tensor, options, run_metadata)
locals:
results = <not found>
self = <local> <tensorflow.python.client.session.Session object at 0x14cae45105e0>
self._do_run = <local> <bound method BaseSession._do_run of <tensorflow.python.client.session.Session object at 0x14cae45105e0>>
handle = <local> None
final_targets = <local> [<tf.Operation 'conformer_block_01_conv_mod_bn/batch_norm/cond/Merge_1' type=Merge>, <tf.Operation 'conformer_block_02_conv_mod_bn/batch_norm/cond/Merge_1' type=Merge>, <tf.Operation 'optim_and_step_incr' type=NoOp>]
final_fetches = <local> [<tf.Tensor 'IteratorGetNext:2' shape=(?,) dtype=int32>, <tf.Tensor 'IteratorGetNext:3' shape=(?,) dtype=int32>, <tf.Tensor 'objective/loss/add:0' shape=() dtype=float32>, <tf.Tensor 'objective/loss/loss/loss_ctc/Sum:0' shape=() dtype=float32>, <tf.Tensor 'objective/loss/error/loss_ctc_error/Sum:..., len = 13
options = <local> None
run_metadata = <local> None
File "/work/tools/asr/python/3.8.0_tf_2.3.4-haswell+cuda10.1+mkl/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1358, in BaseSession._do_run
line: return self._do_call(_run_fn, feeds, fetches, targets, options,
run_metadata)
locals:
self = <local> <tensorflow.python.client.session.Session object at 0x14cae45105e0>
self._do_call = <local> <bound method BaseSession._do_call of <tensorflow.python.client.session.Session object at 0x14cae45105e0>>
_run_fn = <local> <function BaseSession._do_run.<locals>._run_fn at 0x14c6abacef70>
feeds = <local> {<tensorflow.python._pywrap_tf_session.TF_Output object at 0x14c99f3efeb0>: array(True)}
fetches = <local> [<tensorflow.python._pywrap_tf_session.TF_Output object at 0x14c99f62f870>, <tensorflow.python._pywrap_tf_session.TF_Output object at 0x14c99f62fd30>, <tensorflow.python._pywrap_tf_session.TF_Output object at 0x14c99e75bdf0>, <tensorflow.python._pywrap_tf_session.TF_Output object at 0x14c99ede317..., len = 13
targets = <local> [<tensorflow.python._pywrap_tf_session.TF_Operation object at 0x14c99f1e47b0>, <tensorflow.python._pywrap_tf_session.TF_Operation object at 0x14c6abd86170>]
run_metadata = <local> None
File "/work/tools/asr/python/3.8.0_tf_2.3.4-haswell+cuda10.1+mkl/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1384, in BaseSession._do_call
line: raise type(e)(node_def, op, message)
locals:
type = <builtin> <class 'type'>
e = <not found>
node_def = <local> name: "extern_data/placeholders/batch_dim"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
}
}
}
op = <local> <tf.Operation 'extern_data/placeholders/batch_dim' type=Placeholder>
message = <local> "2 root error(s) found.\n (0) Invalid argument: You must feed a value for placeholder tensor 'extern_data/placeholders/batch_dim' with dtype int32\n\t [[node extern_data/placeholders/batch_dim (defined at work/asr4/rossenbach/sisyphus_work_folders/tts_asr_2021_work/i6_core/tools/git/CloneGitRepo..., len = 877
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: You must feed a value for placeholder tensor 'extern_data/placeholders/batch_dim' with dtype int32
[[node extern_data/placeholders/batch_dim (defined at work/asr4/rossenbach/sisyphus_work_folders/tts_asr_2021_work/i6_core/tools/git/CloneGitRepositoryJob.3OlznTemDM9T/output/repository/returnn/tf/network.py:175) ]]
[[optimize/gradients/lstm0_fw/rec/NativeLstm2_grad/tuple/control_dependency_1_accum_grad/FloorMod/ReadVariableOp/_869]]
(1) Invalid argument: You must feed a value for placeholder tensor 'extern_data/placeholders/batch_dim' with dtype int32
[[node extern_data/placeholders/batch_dim (defined at work/asr4/rossenbach/sisyphus_work_folders/tts_asr_2021_work/i6_core/tools/git/CloneGitRepositoryJob.3OlznTemDM9T/output/repository/returnn/tf/network.py:175) ]]
This is because we recently added batch_dim as a mandatory feed, but were lacking a test case that covers this. Of course, in the case of the tf.data pipeline, batch_dim would not be fed (i.e. it would not be a tf.placeholder) but would instead come directly from the pipeline.
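To make that concrete, here is a minimal, self-contained sketch (plain TF 2 eager with placeholder data, not RETURNN code) of where the batch dim lives in such a pipeline: it is simply the leading dimension of whatever the iterator yields, for example of a [B]-shaped seq-length tensor like the size:audio_features:0 fetch above, so it can be read off the data rather than fed through a placeholder:

import tensorflow as tf

# Variable-length integer sequences of lengths 3, 5, 2 and 4.
ds = tf.data.Dataset.from_tensor_slices(tf.constant([3, 5, 2, 4]))
ds = ds.map(lambda n: {"data": tf.range(n, dtype=tf.int32), "size": n})
# Pad each batch to the longest sequence it contains (default padded_shapes).
ds = ds.padded_batch(batch_size=2, drop_remainder=False)

for batch in ds:
    # The batch dim comes from the data itself, no feed_dict involved.
    batch_dim = tf.shape(batch["size"])[0]
    print(int(batch_dim), batch["data"].shape)  # prints "2 (2, 5)" then "2 (2, 4)"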
For anyone working on this: one reason this issue was not noticed earlier is that we also lack a test case for it. So we should first write a (very simple) test case which currently fails due to this issue, and then fix it.
I noticed we actually do already have test cases for this, for example test_engine_train_new_dataset_pipeline. So why did this go unnoticed? Maybe because that specific test case never makes use of the batch dim, so it does not trigger the issue. We could then add a copy of test_engine_train_new_dataset_pipeline with a slightly more complex network where the batch dim is relevant (but still as trivial as possible), as sketched below.
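As a rough sketch (an assumption about what such a test could look like, not the actual test), the copied test case could use a network dict that stays as trivial as possible but contains one layer that explicitly needs the batch dim, e.g. a range_in_axis layer over the batch axis:

network = {
    "hidden": {"class": "linear", "activation": "tanh", "n_out": 5, "from": "data"},
    # Creates a [B]-shaped range over the batch axis, so constructing/running it
    # requires the batch dim; "is_output_layer" forces construction even though
    # nothing else depends on it.
    "batch_idx": {"class": "range_in_axis", "axis": "B", "from": "hidden",
                  "is_output_layer": True},
    "output": {"class": "softmax", "loss": "ce", "from": "hidden"},
}

With the current behavior, training with this network through the tf.data pipeline should fail with the same batch_dim placeholder error as above, and pass once the fix is in.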