T2F
How do we give input to the model, and where are the processed images stored?
I have trained the model, but now I need to test it. I took demo.py as inspiration for a new demo and am trying to pass my own caption as input, but I do not know how to do so.
import torch as th
import pro_gan_pytorch.PRO_GAN as pg  # assumed import path for the pro-gan-pth package

device = th.device("cuda" if th.cuda.is_available() else "cpu")
# load the pre-trained shadow generator weights for the demo
gen = th.nn.DataParallel(pg.Generator(depth=9))
gen.load_state_dict(th.load("GAN_GEN_SHADOW_8.pth", map_location=str(device)))
How do I change the above code to make my trained model work?
Hello @MeteoRex11, did you solve this problem? I am facing the same issue; please help.
You would need to store the text in torch file format, with the text embeddings already computed in it. This is a very tedious process, hence I did not write a script for it.
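For reference, a minimal sketch of what that could look like, assuming a trained text_encoder and a tensor of padded token indices fixed_captions (both names are placeholders here, not something the repo provides out of the box):

import torch as th

# assumption: fixed_captions is a LongTensor of padded token indices
# and text_encoder is the trained T2F text encoder
with th.no_grad():
    embeddings = text_encoder(fixed_captions)

# store the pre-computed embeddings in torch file format,
# so the demo can load them without re-running the encoder
th.save({"captions": fixed_captions, "embeddings": embeddings}, "text_embeddings.pth")

# later, in the demo:
blob = th.load("text_embeddings.pth", map_location="cpu")
fixed_embeddings = blob["embeddings"]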
I got it working after a few CPU-specific modifications for my laptop, using the code below.
Referred from: an IPYNB notebook. Note: depending on your PyTorch version, a few minor changes might be required.
After importing your condition augmenter, text encoder and GAN generator...
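For completeness, the imports I am assuming here follow the T2F repo layout; the exact module paths may differ in your checkout:

import torch as th
from networks.TextEncoder import Encoder
from networks.ConditionAugmentation import ConditionAugmentor
from pro_gan_pytorch.PRO_GAN import ConditionalProGAN

device = th.device("cpu")  # CPU-only, per the modifications mentioned above

# text_encoder, condition_augmenter and c_pro_gan are then constructed
# and their trained weights loaded, as in the original demo.py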
import pickle

# load the processed-text pickle, which contains the 'vocab' (index -> word)
# and 'rev_vocab' (word -> index) mappings
file_name = 'path/to/your/.pkl'
with open(file_name, "rb") as pick:
    obj = pickle.load(pick)

max_caption_len = 100

# tokenize the caption and map each known word to its vocabulary index
in_str = input('Enter your caption : ')
in_str_tok = in_str.replace('_', ' ').split()
caption = [obj['rev_vocab'][tok.strip()]
           for tok in in_str_tok if tok.strip() in obj['rev_vocab']]

# sanity check: decode the indices back into words, dropping any <pad> tokens
full_str = [obj['vocab'][ind] for ind in caption]
str_proc = list(filter('<pad>'.__ne__, full_str))

# pad (or truncate) the caption to the fixed length expected by the encoder
if len(caption) < max_caption_len:
    caption.extend([obj['rev_vocab']["<pad>"]] * (max_caption_len - len(caption)))
elif len(caption) > max_caption_len:
    caption = caption[:max_caption_len]

fixed_captions = th.tensor([caption], dtype=th.long)
print("Text initialized!")

# encode the caption, detach the embedding from the graph,
# and run it through the condition augmenter
fixed_embeddings = text_encoder(fixed_captions)
fixed_embeddings = th.from_numpy(fixed_embeddings.detach().cpu().data.numpy()).to(device)
fixed_c_not_hats, mus, _ = condition_augmenter(fixed_embeddings)

# fill the remainder of the latent vector with zeros
fixed_noise = th.zeros(len(fixed_captions),
                       c_pro_gan.latent_size - fixed_c_not_hats.shape[-1]).to(device)
fixed_gan_input = th.cat((fixed_c_not_hats, fixed_noise), dim=-1)
print("GAN input prepared")
And then...
import matplotlib.pyplot as plt
%matplotlib inline

# generate images from the prepared GAN input, save them as a grid,
# and display the result
create_grid(
    samples=c_pro_gan.gen(
        fixed_gan_input,
        4,    # generation depth
        1.0   # alpha (fade-in complete)
    ),
    scale_factor=1,
    img_file='output.png')

img = plt.imread('output.png')
plt.figure()
plt.imshow(img)
plt.show()
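In case create_grid is not already in scope (it comes with the T2F demo code), a minimal stand-in built on torchvision could look like the sketch below; treat it as an assumption on my part, not the repo's actual implementation:

import torch as th
from torchvision.utils import save_image

def create_grid(samples, scale_factor, img_file):
    # optionally upsample the generated batch, then save it as an image grid
    if scale_factor > 1:
        samples = th.nn.functional.interpolate(samples, scale_factor=scale_factor)
    save_image(samples, img_file, normalize=True)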