
What’s the best way to use Color palette and another image to condition outputs?

Open amaarora opened this issue 1 year ago • 2 comments

Thank you authors for open sourcing your amazing work.

What would be the best way to use a color palette as a conditioning input for image generation and image retrieval, please?

amaarora avatar Jun 29 '24 18:06 amaarora

This is what I have tried so far: I passed the color palette as a text-style conditioning string and encoded it with the text tokenizer.

caption = 'a nice house with outdoor view [S_1]'
bboxes = '[S_1] v0=0 v1=250 v2=420 v3=100 potted plant ' \
         'v0=700 v1=720 v2=740 v3=850 bottle [S_2]'
color_palette = '[S_2] color = 2 R = 79 G = 158 B = 143 R = 29 G = 107 B = 137 [S_3]'
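As a sanity check on the palette string above, here is a small hypothetical helper (not part of the ml-4m API, just an assumption about the string format) that parses the R/G/B triplets back out:

```python
import re

def parse_palette(s):
    # Extract every integer following 'R = ', 'G = ', or 'B = ' and
    # regroup them into (R, G, B) tuples. Hypothetical helper, only for
    # verifying the conditioning string is well-formed.
    vals = [int(v) for v in re.findall(r"[RGB] = (\d+)", s)]
    return list(zip(vals[0::3], vals[1::3], vals[2::3]))

color_palette = '[S_2] color = 2 R = 79 G = 158 B = 143 R = 29 G = 107 B = 137 [S_3]'
print(parse_palette(color_palette))  # [(79, 158, 143), (29, 107, 137)]
```

If the number of parsed triplets does not match the declared `color = 2` count, the conditioning string is likely malformed.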

Next, I build batched_sample:

batched_sample = {}

# Initialize target modalities
for target_mod, ntoks in zip(target_domains, tokens_per_target):
    batched_sample = init_empty_target_modality(batched_sample, MODALITY_INFO, target_mod, 1, ntoks, device)
    
batched_sample = custom_text(
    batched_sample, input_text=caption, eos_token='[EOS]', 
    key='caption', device=device, text_tokenizer=text_tok
)
batched_sample = custom_text(
    batched_sample, input_text=bboxes, eos_token='[EOS]', 
    key='det', device=device, text_tokenizer=text_tok
)
batched_sample = custom_text(
    batched_sample, input_text=color_palette, eos_token='[EOS]', 
    key='color_palette', device=device, text_tokenizer=text_tok
)

And finally I create out_dict and dec_dict, but the decode_dict call fails with an error.

out_dict = sampler.generate(
    batched_sample, schedule, text_tokenizer=text_tok, 
    verbose=True, seed=0,
    top_p=top_p, top_k=top_k,
)
dec_dict = decode_dict(
    out_dict, toks, text_tok, 
    image_size=224, patch_size=16,
    decoding_steps=1
)

amaarora avatar Jun 30 '24 00:06 amaarora

I also want to confirm the best way to extract features for retrieval, conditioning on both the caption and the color palette.

This is what I have so far based on the input notebook.


# Generation configurations
cond_domains = ["caption", "color_palette"]
target_domains = ["tok_dinov2_global"]
tokens_per_target = [16]
generation_config = {
    "autoregression_schemes": ["roar"],
    "decoding_steps": [1],
    "token_decoding_schedules": ["linear"],
    "temps": [2.0],
    "temp_schedules": ["onex:0.5:0.5"],
    "cfg_scales": [1.0],
    "cfg_schedules": ["constant"],
    "cfg_grow_conditioning": True,
}
top_p, top_k = 0.8, 0.0
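For reference, with top_p=0.8 nucleus filtering keeps only the smallest set of highest-probability tokens whose cumulative mass reaches 0.8, then renormalizes. A minimal sketch of the idea (my own illustration, not the GenerationSampler's actual implementation):

```python
import numpy as np

def top_p_filter(probs, top_p=0.8):
    # Sort tokens by descending probability and find how many are needed
    # to reach a cumulative mass of top_p (including the token that
    # crosses the threshold), then zero out the rest and renormalize.
    # Sketch of nucleus sampling; not the sampler's real code path.
    order = np.argsort(-probs)
    cum = np.cumsum(probs[order])
    cutoff = np.searchsorted(cum, top_p) + 1
    filtered = np.zeros_like(probs)
    kept = order[:cutoff]
    filtered[kept] = probs[kept]
    return filtered / filtered.sum()

probs = np.array([0.6, 0.25, 0.1, 0.05])
print(top_p_filter(probs, top_p=0.8))  # only the top two tokens survive
```

Setting top_k=0.0 here presumably disables top-k filtering entirely, so only the nucleus constraint applies.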

schedule = build_chained_generation_schedules(
    cond_domains=cond_domains,
    target_domains=target_domains,
    tokens_per_target=tokens_per_target,
    **generation_config,
)

fm_model = FM.from_pretrained(FM_MODEL_PATH).eval().to(DEVICE)
sampler = GenerationSampler(fm_model)

batched_sample = {}

# Initialize target modalities (batched_sample must exist before this loop)
for target_mod, ntoks in zip(target_domains, tokens_per_target):
    batched_sample = init_empty_target_modality(
        batched_sample, MODALITY_INFO, target_mod, 1, ntoks, DEVICE
    )

batched_sample = custom_text(
    batched_sample,
    input_text=caption,
    eos_token="[EOS]",
    key="caption",
    device=DEVICE,
    text_tokenizer=text_tokenizer,
)
batched_sample = custom_text(
    batched_sample,
    input_text=color_palette,
    eos_token="[EOS]",
    key="color_palette",
    device=DEVICE,
    text_tokenizer=text_tokenizer,
)

out_dict = sampler.generate(
    batched_sample,
    schedule,
    text_tokenizer=text_tokenizer,
    verbose=True,
    seed=0,
    top_p=top_p,
    top_k=top_k,
)

with torch.no_grad():
    dec_dict = decode_dict(
        out_dict,
        {"tok_dinov2_global": vqvae.to(DEVICE)},
        text_tokenizer,
        image_size=IMG_SIZE,
        patch_size=16,
        decoding_steps=1,
    )

combined_features = dec_dict["tok_dinov2_global"]
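For the retrieval step itself, one common approach (my assumption, not something specified by the 4M docs) is to L2-normalize the decoded global feature and rank a gallery of precomputed image features by cosine similarity:

```python
import numpy as np

def retrieve(query, gallery, top_k=3):
    # Rank gallery rows by cosine similarity to the query vector.
    # `query` would be the decoded tok_dinov2_global feature for the
    # conditioned sample; `gallery` holds precomputed features for the
    # candidate images. Both are L2-normalized, so the dot product
    # equals cosine similarity.
    q = query / np.linalg.norm(query)
    g = gallery / np.linalg.norm(gallery, axis=1, keepdims=True)
    sims = g @ q
    order = np.argsort(-sims)[:top_k]
    return order, sims[order]

# Synthetic demo: item 42 is a near-duplicate of the query,
# so it should rank first.
rng = np.random.default_rng(0)
gallery = rng.normal(size=(100, 768))
query = gallery[42] + 0.01 * rng.normal(size=768)
idx, sims = retrieve(query, gallery)
print(idx[0])  # 42
```

The same gallery features can be reused across queries, so only the conditioned forward pass needs to run per retrieval.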

amaarora avatar Jun 30 '24 04:06 amaarora