
How do we guarantee reasonable conditional generation when training the transformer?

Open · lukun199 opened this issue 3 years ago · 2 comments

Hello, thanks for the awesome code. I ran into a problem while trying to understand how the transformer learns in the third stage.

In the segmentation- and depth-conditioned generation tasks, we train the transformer with F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1)) in https://github.com/CompVis/taming-transformers/blob/9d17ea64b820f7633ea6b8823e1f78729447cb57/taming/models/cond_transformer.py#L286, where target and logits are defined in https://github.com/CompVis/taming-transformers/blob/9d17ea64b820f7633ea6b8823e1f78729447cb57/taming/models/cond_transformer.py#L90-L104. So we learn z_indices from cz_indices = torch.cat((c_indices, a_indices), dim=1). I just wonder why the network does not collapse to simply memorizing the z_indices, since they are part of its own input.
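For reference, here is a minimal sketch of what those linked lines compute. It is paraphrased rather than copied from the repository: the transformer is assumed to be a causal (GPT-style) model that returns per-position logits directly, and the pkeep / vocab_size arguments are illustrative defaults standing in for the repo's configuration.

```python
import torch
import torch.nn.functional as F

def transformer_training_step(transformer, c_indices, z_indices,
                              pkeep=1.0, vocab_size=1024):
    """Simplified sketch of the conditional training step (not the exact repo code).

    c_indices: condition tokens (e.g. from the segmentation mask), shape (B, Lc)
    z_indices: image tokens from the VQGAN encoder, shape (B, Lz)
    """
    if pkeep < 1.0:
        # optionally corrupt a fraction of the image tokens (regularization)
        keep = torch.bernoulli(pkeep * torch.ones_like(z_indices, dtype=torch.float)).long()
        random_tokens = torch.randint_like(z_indices, vocab_size)
        a_indices = keep * z_indices + (1 - keep) * random_tokens
    else:
        a_indices = z_indices

    # condition tokens and (possibly corrupted) image tokens form one sequence
    cz_indices = torch.cat((c_indices, a_indices), dim=1)

    # feed everything except the last token; the causal attention mask means
    # the output at position i only attends to positions <= i
    logits = transformer(cz_indices[:, :-1])          # assumed shape (B, Lc+Lz-1, vocab)

    # drop the outputs that belong to the condition: the output at position
    # Lc-1 is the prediction for the first image token, and so on
    logits = logits[:, c_indices.shape[1] - 1:]

    # the target is the clean z_indices, shifted one step ahead of the input
    target = z_indices
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
    return loss
```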

I found in the Colab notebook that even when the z_indices are chosen randomly, the model can still behave well given proper c_indices (in that case, the c_indices come from the segmentation mask). I am just curious how the model learns under such relatively weak supervision.

lukun199 · Sep 02 '21

@lukun199 Hi, have you figured that out? I am also curious about this part of the code.

IceClear · Mar 29 '22

The GPT is trained to predict input[i] from the preceding tokens input[:i], without ever looking at input[i] itself. At test time it predicts the i-th output from the condition and its own previous predictions prediction[:i].
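In other words, test-time generation is a standard autoregressive sampling loop over the condition tokens. A minimal sketch of that loop (not the repository's exact sample() code; the transformer is again assumed to return logits of shape (batch, length, vocab)):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_z_indices(transformer, c_indices, steps, temperature=1.0):
    """Hypothetical sketch of conditional autoregressive sampling.

    Starting from the condition tokens alone, each image token is sampled from
    p(z_i | z_<i, c) and appended to the sequence, so the model only ever uses
    the condition and its own previous predictions.
    """
    x = c_indices  # running sequence: condition tokens, then generated tokens
    for _ in range(steps):
        logits = transformer(x)                   # per-position logits
        logits = logits[:, -1, :] / temperature   # prediction for the next token
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        x = torch.cat((x, next_token), dim=1)
    # everything after the condition is the generated z_indices
    return x[:, c_indices.shape[1]:]
```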

ZhuXiyue · Dec 11 '22