How do we guarantee reasonable conditional generation when training the transformer?
Hello, thanks for the awesome code. I ran into a question while trying to understand how the transformer learns in the third stage.
In the segmentation- and depth-conditioned generation tasks, we train the transformer with `F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))` in https://github.com/CompVis/taming-transformers/blob/9d17ea64b820f7633ea6b8823e1f78729447cb57/taming/models/cond_transformer.py#L286, where `target` and `logits` are defined in https://github.com/CompVis/taming-transformers/blob/9d17ea64b820f7633ea6b8823e1f78729447cb57/taming/models/cond_transformer.py#L90-L104.
So we learn `z_indices` from `cz_indices = torch.cat((c_indices, a_indices), dim=1)`. I just wonder why the network does not collapse into simply memorizing the `z_indices`?
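For context, this is roughly how I read the training step (a toy sketch with made-up sizes, where a `Linear` layer stands in for the real GPT just to produce per-position logits; it is not the repo's actual model):

```python
import torch
import torch.nn.functional as F

vocab_size = 16                            # made-up toy codebook size
c_indices = torch.tensor([[7, 2, 5]])      # condition codes (e.g. segmentation)
z_indices = torch.tensor([[11, 4, 9, 1]])  # image codes from the VQGAN encoder

# Concatenate condition and image codes and drop the last token, so that
# output position i of the transformer predicts token i+1 of the sequence.
cz_indices = torch.cat((c_indices, z_indices), dim=1)
fake_gpt = torch.nn.Linear(1, vocab_size)  # stand-in, NOT the real GPT
logits = fake_gpt(cz_indices[:, :-1].float().unsqueeze(-1))

# Keep only the outputs at image-code positions: output i is p(z_i | c, z_{<i}).
logits = logits[:, c_indices.shape[1] - 1:]
target = z_indices

loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
print(loss.item())
```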
I found in the Colab notebook that even when the `z_indices` are chosen randomly, the model can still behave well given proper `c_indices` (in that case, `c_indices` comes from the segmentation mask). I am just curious how the model learns under such relatively weak supervision.
@lukun199 Hi, have you figured that out? I am also curious about this part of the code.
The GPT tries to predict `input[i]` based on `input[:i]` without ever looking at `input[i]` itself. At test time it predicts the i-th output from the condition and `prediction[:i]`.
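In other words, the causal attention mask inside the GPT is what rules out copying. A rough illustration of the standard lower-triangular mask (not the repo's code, just the usual construction):

```python
import torch

# Causal (lower-triangular) attention mask for a sequence of length 6
# (condition codes + image codes, with the last token dropped).
T = 6
causal_mask = torch.tril(torch.ones(T, T)).bool()
print(causal_mask)
# Row i is what the model may attend to when producing the logit that
# predicts token i+1: only positions <= i. So the logit for z_i depends on
# the condition and on z_{<i}, never on z_i itself, and the training loss
# cannot be minimized by simply memorizing/copying the input z_indices.
```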