OOTDiffusion
OOTDiffusion copied to clipboard
代码中有地方不明白,为什么这么处理
为什么: https://github.com/levihsu/OOTDiffusion/blob/main/ootd/pipelines_ootd/pipeline_ootd.py#L586
if ...
...
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
if do_classifier_free_guidance:
prompt_embeds = torch.cat([prompt_embeds, prompt_embeds]) # here. why not torch.cat([prompt_embeds, uncond_input ])
这里在做classifier free guidance的时候,不是拼一个null_embeds,而是仍用prompt_embeds?况且其上还拿到了uncond_input 来着
我们这里没有对prompt embeds做classifier-free guidance,而是对garment latents做,可以看一下这个函数https://github.com/levihsu/OOTDiffusion/blob/1347a6b88118d7c508966d6e59dd641a236ded86/ootd/pipelines_ootd/pipeline_ootd.py#L701 论文这周内就会放出来,具体细节可以稍等一下哈