ar-vae
main() on train_image_vae.py
Hello Dr. Pati,
I am trying to run the train_image_vae.py script on Colab, where I have copied over all the necessary components. Calling the main() function without arguments raises the following error:
TypeError: main() missing 13 required positional arguments: 'dataset_type', 'batch_size', 'num_epochs', 'lr', 'beta', 'capacity', 'gamma', 'delta', 'dec_dist', 'train', 'log', 'rand', and 'reg_type'
I then tried passing the required parameters, using the values specified in class ImageVAETrainer(Trainer), as follows:
main('EchoLV', 2, 10, 1e-04, 4.0, 0.0, 10.0, 1.0, 'bernoulli', 'train', 'log', 'rand', 'volume')
but it seems to read the 'volume' argument letter by letter. Here is the error I get:
in main(dataset_type, batch_size, num_epochs, lr, beta, capacity, gamma, delta, dec_dist, train, log, rand, reg_type)
     35             reg_dim = []
     36             for r in reg_type:
---> 37                 reg_dim.append(attr_dict[r])
     38     else:
     39         reg_dim = [0]
KeyError: 'v'
I also tried passing ['volume'] instead, but ran into: TypeError: unhashable type: 'list'
What am I missing? Any help would be appreciated. Thank you!
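Both errors come from how reg_type is consumed, and they can be reproduced outside the script. This is a minimal sketch: the attr_dict below is a hypothetical toy stand-in, and original_lookup is a hypothetical, simplified copy of the branching in main() (the 'all' case is left out).

```python
# Toy stand-in for the script's attr_dict; the 'volume' key and its index are hypothetical.
attr_dict = {'volume': 0}


def original_lookup(reg_type):
    """Simplified copy of the original branching in main() (the 'all' case is omitted)."""
    if len(reg_type) == 1:
        # a one-element list lands here and is used whole as a dict key
        return [attr_dict[reg_type]]
    # a bare string lands here, and iterating over it yields single characters
    return [attr_dict[r] for r in reg_type]


errors = {}
for arg in ('volume', ['volume']):
    try:
        original_lookup(arg)
    except (KeyError, TypeError) as exc:
        errors[repr(arg)] = exc
        print(repr(arg), '->', type(exc).__name__, exc)
```

The string 'volume' has six characters, so it skips the single-attribute branch and fails on the first letter, while ['volume'] takes that branch and fails because a list cannot be a dict key.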
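It's been a while since I last looked at this.
There seems to be a bug in the code: a bare string such as 'volume' has more than one character, so the loop iterates over its letters ('v', 'o', ...), while a one-element list such as ['volume'] reaches attr_dict[reg_type], which uses the list itself as a dict key. You can try changing the lines below (73-86 in main()):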
if len(reg_type) != 0:
    if len(reg_type) == 1:
        if reg_type[0] == 'all':
            reg_dim = []
            for r in attr_dict.keys():
                if r == 'digit_identity' or r == 'color':
                    continue
                reg_dim.append(attr_dict[r])
        else:
            reg_dim = [attr_dict[reg_type]]
    else:
        reg_dim = []
        for r in reg_type:
            reg_dim.append(attr_dict[r])
to
# wrap a bare string such as 'volume' into the list ['volume']
if isinstance(reg_type, str):
    reg_type = [reg_type]
# the rest follows the original logic
if len(reg_type) != 0:
    if len(reg_type) == 1:
        if reg_type[0] == 'all':
            reg_dim = []
            for r in attr_dict.keys():
                if r == 'digit_identity' or r == 'color':
                    continue
                reg_dim.append(attr_dict[r])
        else:
            # look up the single name; the list itself is unhashable
            reg_dim = [attr_dict[reg_type[0]]]
    else:
        reg_dim = []
        for r in reg_type:
            reg_dim.append(attr_dict[r])
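If it helps, the same handling can be pulled out of main() into a small function that is easy to test on its own. This is only a sketch: resolve_reg_dim and toy_attr_dict are hypothetical names, the toy attribute names and indices are made up, and the [0] fallback mirrors the `else: reg_dim = [0]` branch visible in the traceback above.

```python
def resolve_reg_dim(reg_type, attr_dict, skip=('digit_identity', 'color')):
    """Hypothetical standalone version of the reg_type -> reg_dim logic in main()."""
    # accept a bare string as well as a list or tuple of attribute names
    if isinstance(reg_type, str):
        reg_type = [reg_type]
    # empty selection: same fallback as the `else: reg_dim = [0]` branch
    if len(reg_type) == 0:
        return [0]
    # 'all' selects every attribute except the skipped ones
    if len(reg_type) == 1 and reg_type[0] == 'all':
        return [idx for name, idx in attr_dict.items() if name not in skip]
    # fail early with a readable message instead of a KeyError on a single letter
    unknown = [r for r in reg_type if r not in attr_dict]
    if unknown:
        raise KeyError(f'unknown attribute(s) {unknown}; expected some of {sorted(attr_dict)}')
    return [attr_dict[r] for r in reg_type]


# Toy mapping for illustration only; these names and indices are hypothetical.
toy_attr_dict = {'volume': 0, 'length': 1, 'color': 2}
print(resolve_reg_dim('volume', toy_attr_dict))              # [0]
print(resolve_reg_dim(['volume', 'length'], toy_attr_dict))  # [0, 1]
print(resolve_reg_dim('all', toy_attr_dict))                 # [0, 1]
```

Checking for unknown names up front means a typo or a missing dataset attribute is reported by name, rather than surfacing as a KeyError deep inside the training setup.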
However, the code was originally written to work only with the MNIST / dSprites datasets, and it seems you are using a custom dataset (EchoLV). So I am not sure it would work even with this fix; you might have to do some additional refactoring. For example, you might need to define a custom attr_dict.
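For a custom dataset, one way to define attr_dict is to build it from an ordered list of attribute names. This is a minimal sketch under assumptions: build_attr_dict and ECHOLV_ATTRIBUTES are hypothetical, 'volume' is the only attribute taken from the question, and it assumes each index should point at the matching column of the dataset's attribute labels. It is worth checking how the trainer uses reg_dim before relying on that.

```python
# Hypothetical attribute list for EchoLV; order is assumed to match the
# columns of the dataset's attribute/label array.
ECHOLV_ATTRIBUTES = ['volume']


def build_attr_dict(attribute_names):
    """Map each attribute name to its column index, rejecting duplicate names."""
    if len(set(attribute_names)) != len(attribute_names):
        raise ValueError('attribute names must be unique')
    return {name: idx for idx, name in enumerate(attribute_names)}


attr_dict = build_attr_dict(ECHOLV_ATTRIBUTES)
print(attr_dict)  # {'volume': 0}
```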