uncertainty-baselines
Unsupported data type for TPU String
Getting this error when running baselines/cifar/deterministic.py:
tensorflow/compiler/jit/xla_compilation_cache.cc:334] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
I1017 19:47:33.579260 140642886925376 deterministic.py:260] [deterministic.py:260] Model input shape: (None, 32, 32, 3)
I1017 19:47:33.579456 140642886925376 deterministic.py:261] [deterministic.py:261] Model output shape: (None, 10)
I1017 19:47:33.581461 140642886925376 deterministic.py:262] [deterministic.py:262] Model number of weights: 36497146
I1017 19:47:34.059113 140642886925376 deterministic.py:457] [deterministic.py:457] Starting to run epoch: 0
I1017 19:47:34.920352 140642886925376 tpu.py:1376] [tpu.py:1376] TPU has inputs with dynamic shapes: [<tf.Tensor 'while/Const:0' shape=() dtype=int32>, <tf.Tensor 'while/cond_8/Identity:0' shape=(None,) dtype=int64>, <tf.Tensor 'while/cond_8/Identity_1:0' shape=(None, 1) dtype=int64>, <tf.Tensor 'while/cond_8/Identity_2:0' shape=(None, 32, 32, 3) dtype=float32>, <tf.Tensor 'while/cond_8/Identity_3:0' shape=(None,) dtype=string>, <tf.Tensor 'while/cond_8/Identity_4:0' shape=(None,) dtype=float32>]
I1017 19:47:40.580166 140642886925376 tpu.py:1376] [tpu.py:1376] TPU has inputs with dynamic shapes: [<tf.Tensor 'while/Const:0' shape=() dtype=int32>, <tf.Tensor 'while/cond_8/Identity:0' shape=(None,) dtype=int64>, <tf.Tensor 'while/cond_8/Identity_1:0' shape=(None, 1) dtype=int64>, <tf.Tensor 'while/cond_8/Identity_2:0' shape=(None, 32, 32, 3) dtype=float32>, <tf.Tensor 'while/cond_8/Identity_3:0' shape=(None,) dtype=string>, <tf.Tensor 'while/cond_8/Identity_4:0' shape=(None,) dtype=float32>]
2021-10-17 19:47:44.299236: I tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc:263] Subgraph fingerprint:1696516458553054063
Traceback (most recent call last):
File "baselines/cifar/deterministic.py", line 557, in
My train_iterator returns the following keys:
['id', '_enumerate_added_per_step_id', 'element_id', 'features', 'labels']
The code should only be extracting the features and labels from this iterator, but the 'id' key holds a PerReplica value of dtype string, and the log above does list a dtype=string tensor among the TPU inputs. Maybe this is why?
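One way to confirm which keys would reach the TPU as strings is to inspect the dataset's `element_spec` before it is distributed. The helper below is a minimal sketch, assuming each element is a flat dict whose values are `tf.TensorSpec`-like objects (anything with a `.dtype` whose `.name` is `"string"`, as `tf.string`'s is); `string_keys` is a hypothetical name, not part of uncertainty-baselines or TensorFlow.

```python
from typing import Any, List, Mapping


def string_keys(element_spec: Mapping[str, Any]) -> List[str]:
  """Returns the sorted keys of a flat dict element_spec with a string dtype.

  Hypothetical diagnostic helper. It only looks at top-level entries and
  does not traverse nested structures.
  """
  return sorted(
      key for key, spec in element_spec.items()
      # tf.string.name == "string"; non-spec values without .dtype are skipped.
      if getattr(getattr(spec, "dtype", None), "name", None) == "string")
```

Calling `string_keys(train_dataset.element_spec)` on the un-distributed dataset (`train_dataset` is a hypothetical variable name) should list every string-typed key; given the keys above, `'id'` is the likely candidate.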
Update: I changed the code so that the outputs of train_iterator and validation_iterator no longer contain any strings. Now it works perfectly!
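A minimal sketch of that kind of change, assuming the training step only reads `features` and `labels`: map each element to a dict containing just those keys before the dataset is distributed. This is one possible approach, not necessarily the edit made in this thread or in the linked PR, and `keep_tpu_keys` and `TPU_KEYS` are hypothetical names.

```python
from typing import Any, Dict, Iterable, Mapping

# Hypothetical whitelist of the keys the TPU step consumes. Add any other
# key (e.g. 'element_id') here if it is needed downstream.
TPU_KEYS = ("features", "labels")


def keep_tpu_keys(example: Mapping[str, Any],
                  keys: Iterable[str] = TPU_KEYS) -> Dict[str, Any]:
  """Returns a new dict holding only the whitelisted keys of `example`.

  Everything else, including the string-typed 'id' key, is dropped, so it
  is never fed to the TPU. Missing whitelisted keys are silently skipped.
  """
  return {k: example[k] for k in keys if k in example}
```

Applied as `dataset = dataset.map(keep_tpu_keys)` before `strategy.experimental_distribute_dataset(dataset)`, the string-typed `id` key never becomes a TPU input. A whitelist is used rather than a dtype filter so that any new string field added upstream is also dropped by default.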
Hey! Thanks for catching this. For some reason this was working fine for us internally; hopefully this PR will fix it externally.