ml-4m
ml-4m copied to clipboard
What’s the best way to use a color palette and another image to condition outputs?
Thank you, authors, for open-sourcing your amazing work.
What is the best way to use a color palette for image generation and for image retrieval?
Here is what I have tried so far: I passed the color palette as a text input and tokenized it with the text tokenizer.
# Text conditioning inputs; each is tokenized with the text tokenizer below.
# NOTE(review): the [S_n] sentinel tokens are chained across the three strings
# ([S_1] ends the caption and starts the boxes, [S_2] ends the boxes and starts
# the palette). Whether 4M numbers sentinels per sequence instead is not shown
# here; confirm against the 4M example notebooks.
caption = 'a nice house with outdoor view [S_1]'

# NOTE(review): if v0..v3 are (xmin, ymin, xmax, ymax), the first box has
# v3=100 < v1=250, i.e. an empty box; confirm the intended value.
bboxes = (
    '[S_1] v0=0 v1=250 v2=420 v3=100 potted plant '
    'v0=700 v1=720 v2=740 v3=850 bottle [S_2]'
)

# The number after "color =" appears to match the count of RGB triples (2).
color_palette = (
    '[S_2] color = 2 '
    'R = 79 G = 158 B = 143 '
    'R = 29 G = 107 B = 137 '
    '[S_3]'
)
Next, I build `batched_sample`:
# Build the sampler input dict: target-modality slots first, then the text
# conditioning modalities.
batched_sample = {}

# Initialize target modalities. The body must be indented; as originally
# pasted it was at column 0, which is a SyntaxError.
# NOTE(review): the literal `1` is presumably the batch size; confirm against
# init_empty_target_modality's signature.
for target_mod, ntoks in zip(target_domains, tokens_per_target):
    batched_sample = init_empty_target_modality(
        batched_sample, MODALITY_INFO, target_mod, 1, ntoks, device
    )

# Initialize input (conditioning) modalities. These calls do not depend on the
# loop variables, so they run once, after the loop, in the original order.
# NOTE(review): the schedule used with this sample is not shown; confirm its
# cond_domains include 'det' as well as 'caption' and 'color_palette'.
for key, text in (
    ('caption', caption),
    ('det', bboxes),
    ('color_palette', color_palette),
):
    batched_sample = custom_text(
        batched_sample, input_text=text, eos_token='[EOS]',
        key=key, device=device, text_tokenizer=text_tok,
    )
Finally, I create `out_dict` and `dec_dict`, but the `decode_dict` call fails with an error.
import torch

# Sample the target modalities conditioned on the text inputs in batched_sample.
out_dict = sampler.generate(
    batched_sample, schedule, text_tokenizer=text_tok,
    verbose=True, seed=0,
    top_p=top_p, top_k=top_k,
)

# Decoding is inference-only, so disable autograd. The retrieval decode later
# in this file does the same, and this avoids building a graph for the decoders.
# NOTE(review): the reported decode_dict error is not shown. `toks` is not
# defined in this snippet; presumably it must map every generated tokenized
# target domain to its decoder. Confirm it covers each target in the schedule.
with torch.no_grad():
    dec_dict = decode_dict(
        out_dict, toks, text_tok,
        image_size=224, patch_size=16,
        decoding_steps=1,
    )
I would also like to confirm the best way to extract features from a caption together with a color palette for retrieval.
This is what I have so far, based on the example notebook.
# Generation setup: predict the 16 `tok_dinov2_global` tokens from the caption
# and color palette, using ROAR with a single decoding step.
cond_domains = ["caption", "color_palette"]
target_domains = ["tok_dinov2_global"]
tokens_per_target = [16]

# NOTE(review): temps=[2.0] with an "onex" temperature schedule and top_p=0.8
# makes the prediction stochastic; confirm these values are intended for
# retrieval, where a stable embedding is usually wanted.
generation_config = dict(
    autoregression_schemes=["roar"],
    decoding_steps=[1],
    token_decoding_schedules=["linear"],
    temps=[2.0],
    temp_schedules=["onex:0.5:0.5"],
    cfg_scales=[1.0],
    cfg_schedules=["constant"],
    cfg_grow_conditioning=True,
)
top_p = 0.8
top_k = 0.0

schedule = build_chained_generation_schedules(
    **generation_config,
    cond_domains=cond_domains,
    target_domains=target_domains,
    tokens_per_target=tokens_per_target,
)
# Load the pretrained 4M model, switch it to eval mode on DEVICE, and wrap it
# in a generation sampler.
fm_model = FM.from_pretrained(FM_MODEL_PATH)
fm_model = fm_model.eval().to(DEVICE)
sampler = GenerationSampler(fm_model)
# Build a fresh sampler input for retrieval. Without this reset the snippet
# either raises NameError when run on its own, or extends the dict from the
# earlier cell, which still holds 'det' and the previous entries.
batched_sample = {}

# Empty slot per target domain. The loop body must be indented; as originally
# pasted it was at column 0, which is a SyntaxError.
for target_mod, ntoks in zip(target_domains, tokens_per_target):
    batched_sample = init_empty_target_modality(
        batched_sample, MODALITY_INFO, target_mod, 1, ntoks, DEVICE
    )

# Conditioning text modalities, in the same order as cond_domains.
# NOTE(review): caption ends with [S_1] while color_palette starts with [S_2];
# confirm how 4M expects sentinel tokens in standalone conditioning strings.
for key, text in (("caption", caption), ("color_palette", color_palette)):
    batched_sample = custom_text(
        batched_sample,
        input_text=text,
        eos_token="[EOS]",
        key=key,
        device=DEVICE,
        text_tokenizer=text_tokenizer,
    )

out_dict = sampler.generate(
    batched_sample,
    schedule,
    text_tokenizer=text_tokenizer,
    verbose=True,
    seed=0,
    top_p=top_p,
    top_k=top_k,
)
# Decode the generated tok_dinov2_global tokens with the VQ-VAE, without
# autograd, and keep the result as the retrieval query.
# NOTE(review): presumably these are DINOv2 global features; confirm their
# shape before comparing them against a gallery.
with torch.no_grad():
    dec_dict = decode_dict(
        out_dict,
        {"tok_dinov2_global": vqvae.to(DEVICE)},
        text_tokenizer,
        decoding_steps=1,
        patch_size=16,
        image_size=IMG_SIZE,
    )

combined_features = dec_dict["tok_dinov2_global"]