improved-diffusion
When sampling a model trained with class_cond=True, do we need to give it the ImageNet ground-truth labels?
As in the title, I am wondering if I need to pass the class label when sampling a trained super-resolution network.
Isn't it wrong to use the ground-truth labels during sampling? How could I have them in a real case without a classifier?
Additional question: does class conditioning improve the super-resolution model's performance?
Many thanks, Stefano
If you read the manual, the answer is there:
"If you wish to train a class-conditional model, name the files like "mylabel1_XXX.jpg", "mylabel2_YYY.jpg", etc., so that the data loader knows that "mylabel1" and "mylabel2" are the labels. Subdirectories will automatically be enumerated as well, so the images can be organized into a recursive structure (although the directory names will be ignored, and the underscore prefixes are used as names)."
You have to rename your training set accordingly.
You may also need to change the NUM_CLASSES variable in script_util.py.
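As a sketch of how the loader can derive labels from filenames named that way (this is my reading of the quote above, so the actual loader code may differ in detail): the label is everything before the first underscore of the basename, directory names are ignored, and the sorted unique label names are mapped to integer ids.

```python
import os

# Hypothetical filenames following the README's naming scheme.
paths = [
    "data/dogs/dog_001.jpg",
    "data/cats/cat_001.jpg",
    "data/cats/cat_002.jpg",
]

# The label is everything before the first underscore of the basename;
# the directory names play no role.
class_names = [os.path.basename(p).split("_")[0] for p in paths]

# Map the sorted unique label names to integer ids.
sorted_classes = {name: i for i, name in enumerate(sorted(set(class_names)))}
labels = [sorted_classes[name] for name in class_names]
print(labels)  # [1, 0, 0]
```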
You mean renaming the images to mylabel1_XXX.jpg, or just creating separate text files?
Yes, exactly: rename your files to mylabelX_xxx.jpg. Please consider that the data loader is recursive, which means you can put each label in a different subfolder if you would like to keep things sorted.
Thanks @choROPeNt for your answer! Actually my question was about the sampling stage. At the beginning I thought we only needed to provide the labels during training, and that during sampling that part would be replaced by the classifier. I was wrong: you need to condition (use the ground-truth labels) even during sampling, because conditioning and guidance are two completely separate things.
However, I am still wondering what kind of improvement these labels can provide during sampling, and also what the purpose of a super-resolution network requiring ground-truth labels could be. An easy answer might be that the super-resolution network is used here only to upsample the generated images.
I hope my message is somewhat clear.
Yes, that's true: even at the sampling stage you have to supply class labels matching your dataset.
In image_sample.py the following lines of code do that:
if args.class_cond:
    classes = th.randint(
        low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
    )
    model_kwargs["y"] = classes
you can modify this to
classes = th.IntTensor([cls] * args.batch_size).to(dist_util.dev())
if you would like to have a specific class sampled.
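A standalone sketch of that modification (cls here is a hypothetical class index; in the actual script you would keep the .to(dist_util.dev()) call from the snippet above):

```python
import torch as th

batch_size = 4
cls = 207  # hypothetical class index to sample

# Repeat the chosen class for every sample in the batch,
# replacing the th.randint call from image_sample.py.
classes = th.tensor([cls] * batch_size, dtype=th.long)
model_kwargs = {"y": classes}
print(model_kwargs["y"].tolist())  # [207, 207, 207, 207]
```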
I don't know the exact difference between guidance and conditioning yet, but the guided-diffusion repo may be a good place to look for answers.
Hi, I have 1000 classes, but their names are strings, not numbers. How can I input the class when I run sample.py? It throws an error if I directly input the string name:
classes = th.IntTensor([cls] * args.batch_size).to(dist_util.dev())
ValueError: too many dimensions 'str'
You can try something like this.
Create a translation dict with all classes, assigning each an integer value:
my_dict = {"cat": 0, "dog": 1}
Your data loader should return a list of strings like this:
my_list = ["cat", "dog", "cat", "cat"]
Now convert the list of strings to a list of integers using the dict:
matched_values = [my_dict[item] for item in my_list if item in my_dict]
for your example you can write:
classes_dict = {"cat": 0, "dog": 1}
cls = ["cat", "dog", "cat", "cat", "dog"]
classes = th.IntTensor([classes_dict[item] for item in cls if item in classes_dict]).to(dist_util.dev())
print("Matched values:", classes)
(Here cls already contains one label per sample, so there is no need to multiply by args.batch_size.)
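One caveat with the `if item in classes_dict` filter: any label missing from the dict is silently dropped, which shrinks the batch without warning. A variant that fails loudly instead (my suggestion, not code from the repo):

```python
import torch as th

classes_dict = {"cat": 0, "dog": 1}
cls = ["cat", "dog", "cat", "cat", "dog"]

# Indexing the dict directly raises KeyError on an unknown label
# instead of silently dropping it from the batch.
classes = th.tensor([classes_dict[item] for item in cls], dtype=th.long)
print(classes.tolist())  # [0, 1, 0, 0, 1]
```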