Training CelebA with attributes

Open nuges01 opened this issue 6 years ago • 3 comments

I am trying to train on CelebA with attributes, but training with attributes does not appear to be well supported in the code. Following the code comments, I changed parse_tfrecord_tf in get_data.py to retrieve and return attr, and I also modified the infer function in train.py like so:

        if hps.direct_iterator:
            # replace with x, y, attr if you're getting CelebA attributes...
            x, y, attr = sess.run(iterator)
        else:
            x, y, attr = iterator()

These changes do not appear to be sufficient: I get errors in the make_batch function in model.py. I fixed that by retrieving attr from sess.run(data) and modifying the function's output, but that only moved the error to f_loss. I then made what I think are the appropriate extensions to accommodate the attributes, but a new error shows up somewhere else. With each change I make, I am less certain that I'm doing the right thing.

Am I doing something wrong? Can someone advise on how to train with attributes, or is this functionality just not fully supported?

Thanks!

nuges01 avatar Jul 26 '18 03:07 nuges01

The change I made was to replace the returned label (i.e. y) with attr in parse_tfrecord_tf, rather than adding attr as an extra output. The rest of the code, such as make_batch and f_loss, should then still work.
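A minimal sketch of the replacement described above, written in plain Python rather than against the repository's actual TFRecord parsing code. The function name parse_record, the use_attr flag, and the record keys "data", "label", and "attr" are hypothetical stand-ins for illustration only. The point it shows is that the parser keeps returning exactly two values, so callers that unpack x, y do not need to change.

```python
def parse_record(record, use_attr=True):
    """Hypothetical stand-in for a record parser that returns (x, y).

    When use_attr is True, the attribute vector takes the place of the
    class label in the second slot instead of being added as a third output.
    """
    x = record["data"]
    # Swap attr in for the label rather than returning (x, label, attr).
    y = record["attr"] if use_attr else record["label"]
    return x, y


# Illustrative record; real CelebA records carry 40 binary attributes.
record = {"data": [0, 1, 2], "label": 3, "attr": [1, 0, 1]}

# Downstream code still unpacks two values, as it did with labels.
x, y = parse_record(record)
```

Because the return arity is unchanged, only the shape of y differs from the label case, which is what any downstream placeholder has to account for.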

prafullasd avatar Jul 27 '18 01:07 prafullasd

Thanks for responding!

Just so you're aware, the comments in both train.py and get_data.py suggest adding the attr variable rather than replacing y. Namely, "to get CelebA attr, also return attr" and "replace with x, y, attr if you're getting CelebA attributes, also modify get_data", which is where the confusion stemmed from.

I made the changes you suggested, but I also had to change line 147 of model.py from

        Y = tf.placeholder(tf.int32, [None], name='label')

to

        Y = tf.placeholder(tf.int32, [None, 40], name='attr')

to avoid the following error:

        ValueError: Cannot feed value of shape (256, 40) for Tensor 'input/label:0', which has shape '(?,)'

I wonder if this is an appropriate way to deal with the error, since it obviates the need to make any changes in train.py?
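The shape mismatch behind that ValueError can be illustrated with a small plain-Python check. This is only a sketch of the general rule (ranks must match, and None acts as a wildcard dimension), not TensorFlow's actual implementation; the helper name shapes_compatible is hypothetical.

```python
def shapes_compatible(declared, fed):
    """Return True if a fed shape fits a declared shape.

    A None entry in the declared shape matches any size in that position,
    but the number of dimensions must be the same.
    """
    if len(declared) != len(fed):
        return False
    return all(d is None or d == f for d, f in zip(declared, fed))


label_shape = (None,)      # shape of the original label placeholder
attr_shape = (None, 40)    # shape after the change, one column per attribute
batch = (256, 40)          # shape of the attribute batch from the error message

print(shapes_compatible(label_shape, batch))  # False: rank 1 vs rank 2
print(shapes_compatible(attr_shape, batch))   # True
```

A rank-1 placeholder cannot accept a rank-2 batch no matter the batch size, which is consistent with the error disappearing once the placeholder is declared as [None, 40].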

Cheers

nuges01 avatar Jul 27 '18 06:07 nuges01

> The change I did was to replace the label returned (ie y) with attr in parse_tf_record, and not add it as an extra output. Then the rest of the code like make_batch, f_loss etc should still work.

It works for me!!

haolsun avatar Oct 25 '18 16:10 haolsun