
How do we give input to the model, and where are the processed images stored?

Open harbarex opened this issue 6 years ago • 3 comments

I have trained the model, but now I need to test it. I took demo.py as inspiration for a new demo and am trying to give my custom caption as input, but I do not know how to do so.

# load the model for the demo
gen = th.nn.DataParallel(pg.Generator(depth=9))
gen.load_state_dict(th.load("GAN_GEN_SHADOW_8.pth", map_location=str(device)))

How do I change the above code for making my trained model work?

harbarex avatar May 06 '19 22:05 harbarex

Hello @MeteoRex11, did you solve this problem? I am facing the same issue, please help me.

adorsho avatar Oct 21 '19 03:10 adorsho

You would need to store the text in torch file format, with the text embeddings already in it. This is a very tedious process, hence I did not make a script for it.
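For reference, a minimal sketch of what such a torch file could contain. The key names (`captions`, `embeddings`) and the placeholder tensor are illustrative only, not from the T2F repo:

```python
import torch as th

# Hypothetical example: pre-compute embeddings for your captions and
# store them together in a single torch file for later inference.
captions = ["a man wearing glasses", "a woman with long hair"]

# In practice this would come from the text encoder, e.g.
# embeddings = text_encoder(encoded_captions)
# Here a random tensor stands in: (num_captions, seq_len, emb_dim)
embeddings = th.randn(len(captions), 100, 128)

th.save({"captions": captions, "embeddings": embeddings}, "text_data.pth")

# Load it back when running the demo
data = th.load("text_data.pth")
print(data["embeddings"].shape)  # torch.Size([2, 100, 128])
```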

harbarex avatar Oct 25 '19 14:10 harbarex

I got it working after a few CPU-specific modifications for my laptop, using the code below:

Referred from an IPYNB notebook. Note: depending on your PyTorch version, a few minor changes might be required.

After importing your condition augmenter, text encoder, and GAN generator...

import pickle

# load the vocabulary object (contains the 'vocab' and 'rev_vocab' mappings)
file_name = 'path/to/your/.pkl'
with open(file_name, "rb") as pick:
    obj = pickle.load(pick)

max_caption_len = 100
in_str = input('Enter your caption : ')
tokens = in_str.replace('_', ' ').split()
# map each known token to its vocabulary index; unknown tokens are skipped
caption = [obj['rev_vocab'][tok.strip()] for tok in tokens if tok.strip() in obj['rev_vocab']]

# reconstruct the recognised words (dropping '<pad>') as a sanity check
full_str = [obj['vocab'][ind] for ind in caption]
str_proc = list(filter('<pad>'.__ne__, full_str))

# pad or truncate the caption to the fixed length
if len(caption) < max_caption_len:
    caption.extend([obj['rev_vocab']["<pad>"]] * (max_caption_len - len(caption)))
else:
    caption = caption[:max_caption_len]

fixed_captions = th.tensor([caption], dtype=th.long)
print("Text initialized!")

# encode the caption, condition-augment it, then pad with noise
# up to the generator's latent size
fixed_embeddings = text_encoder(fixed_captions)
fixed_embeddings = fixed_embeddings.detach().to(device)
fixed_c_not_hats, mus, _ = condition_augmenter(fixed_embeddings)
fixed_noise = th.zeros(len(fixed_captions),
                       c_pro_gan.latent_size - fixed_c_not_hats.shape[-1]).to(device)
fixed_gan_input = th.cat((fixed_c_not_hats, fixed_noise), dim=-1)
print("Gan input prepared")
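The tokenise/pad logic above can be checked in isolation with a toy vocabulary (the word list here is made up purely for illustration):

```python
# Toy vocabulary to exercise the caption -> index -> padded-list pipeline
vocab = ["<pad>", "a", "man", "wearing", "glasses"]
rev_vocab = {word: idx for idx, word in enumerate(vocab)}
max_caption_len = 8

# tokenize, map to indices, then pad with the '<pad>' index up to fixed length
caption = [rev_vocab[w] for w in "a man wearing glasses".split() if w in rev_vocab]
while len(caption) < max_caption_len:
    caption.append(rev_vocab["<pad>"])

print(caption)  # [1, 2, 3, 4, 0, 0, 0, 0]
```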

And then...

import matplotlib.pyplot as plt
%matplotlib inline
create_grid(
    samples=c_pro_gan.gen(
        fixed_gan_input,
        4,    # depth at which to generate
        1.0   # alpha (fade-in fully complete)
    ),
    scale_factor=1,
    img_file='output.png')

img = plt.imread('output.png')
plt.figure()
plt.imshow(img)

shiv6891 avatar Jul 31 '20 07:07 shiv6891