glow
Training CelebA with attributes
I am trying to train CelebA with attributes, but it looks like training with attributes is not well supported in the code. I made the appropriate changes (per the comments) to the function parse_tfrecord_tf in get_data.py to retrieve and return attr, and also modified the function infer in train.py like so:
if hps.direct_iterator:
    # replace with x, y, attr if you're getting CelebA attributes...
    x, y, attr = sess.run(iterator)
else:
    x, y, attr = iterator()
It appears these changes are not sufficient, as I get errors in the make_batch function in model.py. I fixed that by retrieving attr from sess.run(data) and modifying the function's output, but that only moved the error to f_loss. So I made what I think are the appropriate extensions to accommodate the attributes, but then a new error shows up somewhere else, and with each change I make, I become less certain that I'm doing the right thing.
Am I doing something wrong? Can someone advise on how to train with attributes, or is this functionality just not fully supported?
Thanks!
The change I made was to replace the returned label (i.e. y) with attr in parse_tfrecord_tf, rather than adding it as an extra output. Then the rest of the code, such as make_batch and f_loss, should still work.
Thanks for responding!
Just so you're aware, the comments in both train.py and get_data.py suggest adding the attr variable rather than replacing y, namely "to get CelebA attr, also return attr" and "replace with x, y, attr if you're getting CelebA attributes, also modify get_data". That is where the confusion came from.
I made the changes like you suggested, but I also had to make a change in model.py line 147, from
Y = tf.placeholder(tf.int32, [None], name='label')
to
Y = tf.placeholder(tf.int32, [None, 40], name='attr')
to avoid getting the following error:
ValueError: Cannot feed value of shape (256, 40) for Tensor 'input/label:0', which has shape '(?,)'
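The error comes from a shape check when the batch is fed: 256 examples with 40 attributes each form a batch of shape (256, 40), which cannot match a placeholder declared as [None]. Below is a simplified model of that check, written without TensorFlow; is_feedable is a hypothetical helper, not TensorFlow's code, and None stands for a dimension of unknown size.

```python
from typing import Optional, Sequence


def is_feedable(placeholder_shape: Sequence[Optional[int]],
                value_shape: Sequence[int]) -> bool:
    """Simplified model of a placeholder feed check (hypothetical helper).

    The ranks must match, and each declared dimension must either be None
    (unknown size, accepts anything) or equal the fed dimension.
    """
    if len(placeholder_shape) != len(value_shape):
        return False
    return all(p is None or p == v for p, v in zip(placeholder_shape, value_shape))


batch_shape = (256, 40)  # 256 examples x 40 CelebA attributes

print(is_feedable([None], batch_shape))      # False: rank-1 label placeholder
print(is_feedable([None, 40], batch_shape))  # True: matches the attribute batch
```

Declaring the placeholder as [None, 40] keeps the batch dimension flexible while fixing the attribute dimension, which is why the change at line 147 makes this particular error go away.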
Is this an appropriate way to deal with the error, given that it removes the need to make any changes in train.py?
Cheers
It works for me!!